	Migrated from GitHub
This view is limited to 50 files because it contains too many changes; see the raw diff for the full commit.
- .gitattributes +7 -0
- LICENSE +201 -0
- ORIGINAL_README.md +166 -0
- assets/images/teaser.jpg +0 -0
- assets/videos/apt_exp_1_all.gif +3 -0
- assets/videos/apt_exp_2_all.gif +3 -0
- assets/videos/baodao_exp_1_all.gif +3 -0
- assets/videos/exp_1.gif +3 -0
- assets/videos/exp_2.gif +3 -0
- assets/videos/gf_exp1.gif +3 -0
- assets/videos/gf_exp1.mp4 +3 -0
- demo.ipynb +0 -0
- demo.py +98 -0
- demo/demo.py +98 -0
- demo/requirements.txt +10 -0
- projects/glamm/datasets/__init__.py +7 -0
- projects/glamm/datasets/collate_fns/glamm_collate_fn.py +136 -0
- projects/glamm/datasets/gcg_dataset.py +349 -0
- projects/glamm/datasets/refcoco_segm_dataset.py +195 -0
- projects/glamm/datasets/region_level_dataset.py +297 -0
- projects/glamm/datasets/semantic_seg_dataset.py +424 -0
- projects/glamm/datasets/utils/ade20k_classes.json +30 -0
- projects/glamm/datasets/utils/cocostuff_classes.txt +183 -0
- projects/glamm/datasets/utils/utils.py +131 -0
- projects/glamm/models/glamm.py +183 -0
- projects/glamm/models/region_encoder.py +359 -0
- projects/glamm/utils.py +280 -0
- projects/llava_sam2/configs/sa2va_4b.py +548 -0
- projects/llava_sam2/datasets/ChatUniVi_Dataset.py +389 -0
- projects/llava_sam2/datasets/GCG_Dataset.py +375 -0
- projects/llava_sam2/datasets/Grand_Dataset.py +241 -0
- projects/llava_sam2/datasets/MeVIS_Dataset.py +5 -0
- projects/llava_sam2/datasets/Osprey_Dataset.py +463 -0
- projects/llava_sam2/datasets/ReSAM2_Dataset.py +489 -0
- projects/llava_sam2/datasets/ReVOS_Dataset.py +602 -0
- projects/llava_sam2/datasets/RefCOCO_Dataset.py +338 -0
- projects/llava_sam2/datasets/RefYoutubeVOS_Dataset.py +47 -0
- projects/llava_sam2/datasets/__init__.py +15 -0
- projects/llava_sam2/datasets/collect_fns.py +206 -0
- projects/llava_sam2/datasets/encode_fn.py +144 -0
- projects/llava_sam2/datasets/gcg_process.py +297 -0
- projects/llava_sam2/datasets/grand_process.py +110 -0
- projects/llava_sam2/datasets/utils.py +58 -0
- projects/llava_sam2/datasets/vqa_dataset.py +509 -0
- projects/llava_sam2/deepspeed_zero2_sam2.json +24 -0
- projects/llava_sam2/gradio/app.py +151 -0
- projects/llava_sam2/gradio/app_utils.py +293 -0
- projects/llava_sam2/models/__init__.py +3 -0
- projects/llava_sam2/models/extension/__init__.py +1 -0
- projects/llava_sam2/models/extension/sam2_base.py +281 -0
    	
.gitattributes CHANGED

```diff
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/videos/apt_exp_1_all.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/apt_exp_2_all.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/baodao_exp_1_all.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/exp_1.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/exp_2.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/gf_exp1.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/gf_exp1.mp4 filter=lfs diff=lfs merge=lfs -text
```
    	
LICENSE ADDED

```
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
```
    	
ORIGINAL_README.md ADDED

# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos

[\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va)  [\[📜 arXiv\]](https://arxiv.org/abs/2501.04001) [\[🤗 HuggingFace\]](https://huggingface.co/collections/ByteDance/sa2va-model-zoo-677e3084d71b5f108d00e093) [\[🎥 Introduction\]]() [\[🧑💻 GitHub\]](https://github.com/magic-research/Sa2VA) [\[Online Demo (Sa2VA-4B)\]](https://5512470799b6b35fbc.gradio.live/)

[**Haobo Yuan**](https://yuanhaobo.me/)<sup>1*</sup> · [**Xiangtai Li**](https://scholar.google.com/citations?user=NmHgX-wAAAAJ)<sup>2*†</sup> · [**Tao Zhang**](https://zhang-tao-whu.github.io/)<sup>2,3*</sup> · [**Zilong Huang**](http://speedinghzl.github.io/)<sup>2</sup> · [**Shilin Xu**](https://xushilin1.github.io/)<sup>4</sup> · [**Shunping Ji**](https://scholar.google.com/citations?user=FjoRmF4AAAAJ&hl=en)<sup>3</sup> · [**Yunhai Tong**](https://scholar.google.com/citations?user=T4gqdPkAAAAJ&hl=zh-CN)<sup>4</sup> ·

[**Lu Qi**](https://luqi.info/)<sup>2</sup> · [**Jiashi Feng**](https://sites.google.com/site/jshfeng/)<sup>2</sup> · [**Ming-Hsuan Yang**](https://faculty.ucmerced.edu/mhyang/)<sup>1</sup>

<sup>1</sup>UC Merced    <sup>2</sup>ByteDance Seed    <sup>3</sup>WHU    <sup>4</sup>PKU

† project lead; * the first three authors contributed equally to this work.

## Overview

This repository contains the code for the paper "Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos".

Sa2VA is the first unified model for dense grounded understanding of both images and videos. Unlike existing multi-modal large language models, which are often limited to specific modalities and tasks, Sa2VA supports a wide range of image and video tasks, including referring segmentation and conversation, with minimal one-shot instruction tuning. Sa2VA combines SAM-2, a foundation video segmentation model, with LLaVA, an advanced vision-language model, and unifies text, image, and video into a shared LLM token space.

## Model Zoo

We provide the following models:

| Model Name | Base MLLM | Language Part | HF Link |
|:----------:|:---------:|:-------------:|:-------:|
| Sa2VA-1B | [InternVL2.0-1B](https://huggingface.co/OpenGVLab/InternVL2-1B) | [Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-1B) |
| Sa2VA-4B | [InternVL2.5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B) | [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-4B) |
| Sa2VA-8B | [InternVL2.5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) | [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-8B) |

## Gradio Demos

We provide a script that implements interactive chat using Gradio (requires `gradio==4.42.0`). You can use it to quickly build a chat interface locally.

```shell
PYTHONPATH=. python projects/llava_sam2/gradio/app.py ByteDance/Sa2VA-4B
```

## Quick Start

Our Sa2VA models are available on 🤗 HuggingFace. With very few steps, you can try them on your own data. You can install `demo/requirements.txt` to avoid installing training-only packages.

**Option 1 - Scripts:**

Assuming you have a folder (`PATH_TO_FOLDER`) that contains the frames of a video, you can use the following script to chat with the Sa2VA model or segment objects in the video.

```bash
> cd scripts
> python demo.py PATH_TO_FOLDER --model_path ByteDance/Sa2VA-8B --work-dir OUTPUT_DIR --text "<image>Please describe the video content."
```

If the output contains segmentation results, they will be saved to `OUTPUT_DIR`.

**Option 2 - Jupyter Notebook:**

Please refer to `demo.ipynb`.
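For programmatic use, both options boil down to a few calls; below is a minimal single-image sketch distilled from the `demo.py` bundled in this commit (the image path is a placeholder, and `predict_forward` is the custom inference entry point exposed by the remote-code model):

```python
# Minimal single-image inference, distilled from demo.py in this commit.
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "ByteDance/Sa2VA-4B", torch_dtype="auto", device_map="auto",
    trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
    "ByteDance/Sa2VA-4B", trust_remote_code=True)

image = Image.open("frame_0001.jpg").convert("RGB")  # placeholder path
result = model.predict_forward(
    image=image,
    text="<image>Please describe the image content.",
    tokenizer=tokenizer,
)
print(result["prediction"])  # contains [SEG] when masks were produced
```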
## Demo

<details open>
<summary>Demo 1</summary>
Input Video (Source: La La Land 2016):

Instruction: "Please segment the girl wearing the yellow dress."
</details>

<details open>
<summary>Demo 2</summary>
Input Video (Source: La La Land 2016):

Instruction: "Please segment the main character."
</details>

<details open>
<summary>Demo 3</summary>
Input Video (Source: Internet):

Instruction: "Please segment the person wearing sunglasses."
</details>

<details open>
<summary>Demo 4</summary>
Input Video (Source: Internet):

Instruction: "Please segment the singing girl."
</details>

<details open>
<summary>Demo 5</summary>
Input Video:

Instruction: "What is the atmosphere of the scene?"

Answer: "The scene has a dark and mysterious atmosphere, with the men dressed in suits and ties, and the dimly lit room."
</details>

## Training

<details open>
<summary>Installation</summary>

1. Install Python and PyTorch first:
```bash
> conda create -n vlm python=3.10
> conda activate vlm
> conda install pytorch==2.3.1 torchvision==0.18.1 pytorch-cuda=12.1 cuda -c pytorch -c "nvidia/label/cuda-12.1.0" -c "nvidia/label/cuda-12.1.1"
```

2. Install mmcv:
```bash
> pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
```

3. Install the other dependencies (an environment sanity check follows this section):
```bash
> pip install -r requirements.txt
```
</details>
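Before training, it can help to confirm that the environment actually resolved to the pinned stack above; a minimal hedged check, assuming `torch` and `mmcv` were installed as described:

```python
# Hedged sanity check for the pinned training stack above.
import torch
import mmcv

print(torch.__version__)          # expected to start with "2.3.1"
print(torch.cuda.is_available())  # must be True for GPU training
print(mmcv.__version__)           # expected "2.2.0"
```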
<details open>
<summary>Pretrained Model Preparation</summary>

You are expected to download the following pretrained models and place them in the `./pretrained` directory (a fetch sketch follows this list):
- [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large)
- [InternVL2_5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B)

</details>
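One way to fetch both checkpoints is via `huggingface_hub`; a hedged sketch, assuming the repo and file names on the Hub still match the links above:

```python
# Hedged fetch sketch; repo/file names are taken from the links above
# and may change on the Hub.
from huggingface_hub import hf_hub_download, snapshot_download

# Single SAM-2 checkpoint file into ./pretrained
hf_hub_download(repo_id="facebook/sam2-hiera-large",
                filename="sam2_hiera_large.pt",
                local_dir="./pretrained")

# Full InternVL2.5-4B repository into ./pretrained/InternVL2_5-4B
snapshot_download(repo_id="OpenGVLab/InternVL2_5-4B",
                  local_dir="./pretrained/InternVL2_5-4B")
```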
<details open>
<summary>Data Preparation</summary>

(TODO) Please download the training datasets and place them in the `data` directory. The download link is [here](https://huggingface.co/datasets/Dense-World/Sa2VA-Training).

</details>

<details open>
<summary>Training Script</summary>

Please run the following script to train:
```bash
> bash tools/dist.sh train projects/llava_sam2/configs/sa2va_4b.py 8
```
</details>

## References

If you find this repository useful, please consider citing the following paper:
```
@article{sa2va,
  title={Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos},
  author={Yuan, Haobo and Li, Xiangtai and Zhang, Tao and Huang, Zilong and Xu, Shilin and Ji, Shunping and Tong, Yunhai and Qi, Lu and Feng, Jiashi and Yang, Ming-Hsuan},
  journal={arXiv},
  year={2025}
}
```
    	
assets/images/teaser.jpg ADDED

assets/videos/apt_exp_1_all.gif ADDED (Git LFS)

assets/videos/apt_exp_2_all.gif ADDED (Git LFS)

assets/videos/baodao_exp_1_all.gif ADDED (Git LFS)

assets/videos/exp_1.gif ADDED (Git LFS)

assets/videos/exp_2.gif ADDED (Git LFS)

assets/videos/gf_exp1.gif ADDED (Git LFS)

assets/videos/gf_exp1.mp4 ADDED (Git LFS pointer)

```diff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:272f4246fbb62aa690811e01d5f8aecaac3d157cc01a9859de79675ee5d4f7cf
+size 15332128
```
    	
demo.ipynb ADDED

The diff for this file is too large to render; see the raw diff.
    	
demo.py ADDED

```python
import argparse
import os

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

import cv2
try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")


def parse_args():
    parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
    parser.add_argument('image_folder', help='Path to the folder of video frames')
    parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
    parser.add_argument('--work-dir', default=None, help='The dir to save results.')
    parser.add_argument('--text', type=str, default="<image>Please describe the video content.")
    parser.add_argument('--select', type=int, default=-1)
    args = parser.parse_args()
    return args


def visualize(pred_mask, image_path, work_dir):
    # Overlay the predicted binary mask on the original frame in green.
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)


if __name__ == "__main__":
    cfg = parse_args()
    model_path = cfg.model_path
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    # Collect the frame images from the input folder in sorted order.
    image_files = []
    image_paths = []
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
    for filename in sorted(os.listdir(cfg.image_folder)):
        if os.path.splitext(filename)[1].lower() in image_extensions:
            image_files.append(filename)
            image_paths.append(os.path.join(cfg.image_folder, filename))

    vid_frames = []
    for img_path in image_paths:
        img = Image.open(img_path).convert('RGB')
        vid_frames.append(img)

    if cfg.select > 0:
        # Single-image mode: run on the selected frame (1-indexed) only.
        img_frame = vid_frames[cfg.select - 1]

        print(f"Selected frame {cfg.select}")
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            image=img_frame,
            text=cfg.text,
            tokenizer=tokenizer,
        )
    else:
        # Video mode: run on the full frame sequence.
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            video=vid_frames,
            text=cfg.text,
            tokenizer=tokenizer,
        )

    prediction = result['prediction']
    print(f"The output is:\n{prediction}")

    # A [SEG] token in the answer means per-frame masks were produced.
    if '[SEG]' in prediction and Visualizer is not None:
        _seg_idx = 0
        pred_masks = result['prediction_masks'][_seg_idx]
        for frame_idx in range(len(vid_frames)):
            pred_mask = pred_masks[frame_idx]
            if cfg.work_dir:
                os.makedirs(cfg.work_dir, exist_ok=True)
                visualize(pred_mask, image_paths[frame_idx], cfg.work_dir)
            else:
                os.makedirs('./temp_visualize_results', exist_ok=True)
                visualize(pred_mask, image_paths[frame_idx], './temp_visualize_results')
```
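If `mmengine` is not installed, the same overlay can be approximated with OpenCV and NumPy alone; a hedged sketch, not part of the repository (`overlay_mask` is a made-up helper name):

```python
# Hedged mmengine-free alternative to visualize() above: blend a boolean
# mask into the frame in green using plain OpenCV + NumPy.
import cv2
import numpy as np

def overlay_mask(image_path: str, pred_mask: np.ndarray,
                 out_path: str, alpha: float = 0.4) -> None:
    img = cv2.imread(image_path)
    colored = img.copy()
    colored[pred_mask.astype(bool)] = (0, 255, 0)  # BGR green where mask is set
    blended = cv2.addWeighted(colored, alpha, img, 1 - alpha, 0)
    cv2.imwrite(out_path, blended)
```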
    	
        demo/demo.py
    ADDED
    
    | @@ -0,0 +1,98 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
            from transformers import AutoModelForCausalLM, AutoTokenizer
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import cv2
         | 
| 8 | 
            +
            try:
         | 
| 9 | 
            +
                from mmengine.visualization import Visualizer
         | 
| 10 | 
            +
            except ImportError:
         | 
| 11 | 
            +
                Visualizer = None
         | 
| 12 | 
            +
                print("Warning: mmengine is not installed, visualization is disabled.")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            def parse_args():
         | 
| 16 | 
            +
                parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
         | 
| 17 | 
            +
                parser.add_argument('image_folder', help='Path to image file')
         | 
| 18 | 
            +
                parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
         | 
| 19 | 
            +
                parser.add_argument('--work-dir', default=None, help='The dir to save results.')
         | 
| 20 | 
            +
                parser.add_argument('--text', type=str, default="<image>Please describe the video content.")
         | 
| 21 | 
            +
                parser.add_argument('--select', type=int, default=-1)
         | 
| 22 | 
            +
                args = parser.parse_args()
         | 
| 23 | 
            +
                return args
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            def visualize(pred_mask, image_path, work_dir):
         | 
| 27 | 
            +
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)

if __name__ == "__main__":
    cfg = parse_args()
    model_path = cfg.model_path
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    image_files = []
    image_paths = []
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
    for filename in sorted(os.listdir(cfg.image_folder)):
        if os.path.splitext(filename)[1].lower() in image_extensions:
            image_files.append(filename)
            image_paths.append(os.path.join(cfg.image_folder, filename))

    vid_frames = []
    for img_path in image_paths:
        img = Image.open(img_path).convert('RGB')
        vid_frames.append(img)

    if cfg.select > 0:
        img_frame = vid_frames[cfg.select - 1]

        print(f"Selected frame {cfg.select}")
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            image=img_frame,
            text=cfg.text,
            tokenizer=tokenizer,
        )
    else:
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            video=vid_frames,
            text=cfg.text,
            tokenizer=tokenizer,
        )

    prediction = result['prediction']
    print(f"The output is:\n{prediction}")

    if '[SEG]' in prediction and Visualizer is not None:
        _seg_idx = 0
        pred_masks = result['prediction_masks'][_seg_idx]
        for frame_idx in range(len(vid_frames)):
            pred_mask = pred_masks[frame_idx]
            if cfg.work_dir:
                os.makedirs(cfg.work_dir, exist_ok=True)
                visualize(pred_mask, image_paths[frame_idx], cfg.work_dir)
            else:
                os.makedirs('./temp_visualize_results', exist_ok=True)
                visualize(pred_mask, image_paths[frame_idx], './temp_visualize_results')
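The overlay above depends on mmengine's Visualizer; when that import is unavailable the demo simply skips saving results. A minimal sketch of an equivalent fallback in plain OpenCV follows; the visualize_fallback helper is hypothetical and not part of the repo, it only approximates draw_binary_masks(colors='g', alphas=0.4).

# --- illustrative sketch, not part of demo.py ---
import os

import cv2
import numpy as np


def visualize_fallback(pred_mask, image_path, work_dir, alpha=0.4):
    # Blend a green tint over the predicted binary mask region.
    img = cv2.imread(image_path)
    mask = pred_mask.astype(bool)
    overlay = img.copy()
    overlay[mask] = (0, 255, 0)  # green in BGR
    blended = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
    cv2.imwrite(os.path.join(work_dir, os.path.basename(image_path)), blended)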
    	
demo/requirements.txt ADDED
@@ -0,0 +1,10 @@
torch==2.3.1
torchvision==0.18.1
transformers==4.42.3
opencv-python-headless<4.10
peft<0.14.0
timm==1.0.9
einops==0.8.0
flash_attn
sentencepiece==0.2.0
mmengine<1
    	
projects/glamm/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
from .semantic_seg_dataset import SemanticSegDataset, ADE20kSemanticSegDataset, \
    COCOStuffSemanticSegDataset, PascalPartSemanticSegDataset, PacoSemanticSegDataset
from .gcg_dataset import GCGDataset, GranDfGCGDataset, RefCOCOgGCGDataset, OpenPsgGCGDataset, Flickr30kGCGDataset
from .region_level_dataset import RefCocoGRegionDataset, VisualGenomeRegionDataset
from .refcoco_segm_dataset import ReferSegmDataset
from .utils.utils import *
from .collate_fns.glamm_collate_fn import glamm_collate_fn
    	
projects/glamm/datasets/collate_fns/glamm_collate_fn.py ADDED
@@ -0,0 +1,136 @@
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def glamm_collate_fn(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                     return_hf_format: bool = False,
                     use_varlen_attn: bool = False):
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, (
            'Currently, it is not configured to accommodate the use of '
            'varlen Attention in multimodal training')

    if has_image:
        pixel_values = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_grounding_image:
            grounding_pixel_values.append(example['g_pixel_values'])
        if has_mask:
            if example.get('masks') is not None:
                object_masks.append(example['masks'])
        if has_bboxes:
            if example.get('bboxes') is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if example.get('points') is not None:
                prompt_points.append(example['points'])

    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['pixel_values'] = pixel_values

    if has_grounding_image:
        # Grounding images may differ in shape across the batch, so they are
        # kept as a list rather than stacked into a single tensor.
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks

    if has_bboxes:
        data_dict['bboxes'] = object_bboxes

    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}
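To see the padding scheme above in isolation: glamm_collate_fn right-pads input_ids and then rebuilds the attention mask from the pre-padding lengths, because a tokenizer whose pad token equals its eos token would otherwise mask real eos positions. A self-contained toy sketch of that step (plain torch, invented values, not part of the repo):

# --- illustrative sketch, not part of the repo ---
import torch
from torch.nn.utils.rnn import pad_sequence

input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
ori_length = [len(ids) for ids in input_ids]

# Right-pad to the longest sequence in the batch.
padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
# The mask is derived from the original lengths, not from the pad token id.
attention_mask = torch.zeros_like(padded).bool()
for i, length in enumerate(ori_length):
    attention_mask[i, :length] = True

print(padded)          # tensor([[5, 6, 7], [8, 9, 0]])
print(attention_mask)  # [[True, True, True], [True, True, False]]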
    	
projects/glamm/datasets/gcg_dataset.py ADDED
@@ -0,0 +1,349 @@
import copy
import random
import glob
import json
import logging
import os
import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER

from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square, GCG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class GCGDataset(Dataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 repeats=1,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__()
        self.question_templates = GCG_QUESTIONS
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.max_length = max_length
        self.template_map_fn = BUILDER.build(template_map_fn)

        self.text_data = self.json_file_preprocess(data_path, image_folder)
        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        size = self.image_processor.crop_size

        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.repeats = repeats

    def json_file_preprocess(self, data_path, image_folder=None):
        with open(data_path, 'r') as f:
            json_data = json.load(f)
        return json_data

    @property
    def modality_length(self):
        # A constant placeholder length is reported for every sample.
        length_list = [100] * len(self.text_data)
        return length_list * self.repeats

    def __len__(self):
        return len(self.text_data) * self.repeats

    def real_len(self):
        return len(self.text_data)

    def _parse_annotations(self, ann_info):
        image_path = os.path.join(self.image_folder, ann_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            ann_info['g_pixel_values'] = g_pixel_values

        width, height = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        ann_info['pixel_values'] = image

        caption = ann_info['caption'].strip('"').strip()
        masks, phrases, tokens_positive = [], [], []
        for word, grounding in ann_info["groundings"].items():
            phrases.append(word)
            tokens_positive.append(grounding["token_positives"])

            # Convert segmentation to binary mask
            binary_mask = np.zeros((height, width), dtype=np.uint8)
            for rle in grounding["rle_masks"]:
                m = mask_utils.decode(rle).astype(np.uint8)
                binary_mask += m.squeeze()
            masks.append(binary_mask)

        def sort_by_start_index(items, order):
            return [items[i] for i in order]

        # Sort phrases (and their masks/spans) by where they appear in the caption.
        phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
        masks = sort_by_start_index(masks, phrase_order)
        phrases = sort_by_start_index(phrases, phrase_order)
        tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

        ann_info.update({
            'image_path': image_path,
            'caption': caption,
            'masks': masks,
            'phrases': phrases,
            'tokens_positive': tokens_positive,
        })
        return ann_info

    def create_conversation(self, caption, tokens_positive):
        question = random.choice(self.question_templates).strip()

        # Prepare caption with tags
        def tag_caption(caption, tokens):
            for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
                caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
            return caption

        detailed_answer = tag_caption(caption, tokens_positive)

        question = 'The <image> provides an overview of the picture.\n' + question
        conversation = [{'input': question, 'output': detailed_answer}]
        return conversation

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = {}
        ann_info = copy.deepcopy(self.text_data[index])
        ann_info = self._parse_annotations(ann_info)

        data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values')
        data_dict['pixel_values'] = ann_info.pop('pixel_values')
        if len(ann_info['masks']) == 0:
            return self.__getitem__(0)
        data_dict['masks'] = torch.from_numpy(np.stack(ann_info['masks'], axis=0))

        conversation = self.create_conversation(ann_info['caption'], ann_info['tokens_positive'])
        data_dict['conversation'] = conversation

        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        return data_dict


class GranDfGCGDataset(GCGDataset):
    pass


class RefCOCOgGCGDataset(GCGDataset):
    def json_file_preprocess(self, data_path, image_folder=None):
        with open(data_path, 'r') as f:
            json_data = json.load(f)
        return [list(line.values())[0] for line in json_data]

    def _parse_annotations(self, ann_info):
        image_path = os.path.join(self.image_folder, ann_info['img_file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            ann_info['g_pixel_values'] = g_pixel_values

        width, height = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        ann_info['pixel_values'] = image

        caption = ann_info['caption'].strip('"').strip().lower()
        masks, phrases, tokens_positive = [], [], []
        for detail in ann_info['refs']:
            phrase = detail['sentence']
            if phrase.lower() in caption:
                phrases.append(phrase)
                index = caption.find(phrase)
                end_index = index + len(phrase) if index != -1 else -1
                tokens_positive.append([index, end_index])

                binary_mask = np.zeros((height, width), dtype=np.uint8)
                for seg in detail["segmentation"]:
                    rles = mask_utils.frPyObjects([seg], height, width)
                    m = mask_utils.decode(rles)
                    m = m.astype(np.uint8)
                    binary_mask += m.squeeze()
                masks.append(binary_mask)

        def sort_by_start_index(items, order):
            return [items[i] for i in order]

        phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
        masks = sort_by_start_index(masks, phrase_order)
        phrases = sort_by_start_index(phrases, phrase_order)
        tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

        ann_info.update({
            'image_path': image_path,
            'caption': caption,
            'masks': masks,
            'phrases': phrases,
            'tokens_positive': tokens_positive,
        })
        return ann_info


class OpenPsgGCGDataset(GCGDataset):
    pass


class Flickr30kGCGDataset(GCGDataset):

    def json_file_preprocess(self, data_path, image_folder=None):
        def filter_images(data_infos, min_size):
            return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]

        self.coco = COCO(data_path)
        self.image_ids = self.coco.getImgIds()
        data_infos = []
        total_ann_ids = []
        removed_img_count = 0
        for img_id in self.image_ids:
            info = self.coco.loadImgs([img_id])[0]
            if len(info['caption'].split(' ')) < 3:
                removed_img_count += 1
                continue
            info['filename'] = info['file_name'].split('_')[-1]
            info['height'] = int(info['height'])
            info['width'] = int(info['width'])
            data_infos.append(info)
            ann_ids = self.coco.getAnnIds(imgIds=[img_id])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
        print(f'Removed {removed_img_count} images.')
        data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]

        return data_infos

    def _parse_annotations(self, img_info):
        ann_ids = self.coco.getAnnIds(imgIds=img_info['id'])
        ann_info = self.coco.loadAnns(ann_ids)

        annotations = {'phrases': [], 'caption': img_info['caption'], 'masks': [], 'tokens_positive': []}
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            annotations['g_pixel_values'] = g_pixel_values

        width, height = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        annotations['pixel_values'] = image

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            tokens_positive = ann['tokens_positive']
            phrase = [img_info['caption'][span[0]:span[1]] for span in tokens_positive]
            annotations['phrases'].append(phrase[0])
            annotations['tokens_positive'].append(tokens_positive[0])

            rle = ann['sam_mask']
            mask_decoded = mask_utils.decode(rle).astype(np.uint8)
            annotations['masks'].append(mask_decoded)

        def sort_by_start_index(items, order):
            return [items[i] for i in order]

        phrase_order = sorted(range(len(annotations['tokens_positive'])), key=lambda x: annotations['tokens_positive'][x][0])
        annotations['masks'] = sort_by_start_index(annotations['masks'], phrase_order)
        annotations['phrases'] = sort_by_start_index(annotations['phrases'], phrase_order)
        annotations['tokens_positive'] = sort_by_start_index(annotations['tokens_positive'], phrase_order)

        return annotations


if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE
    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
    dataset = Flickr30kGCGDataset(
        image_folder='data/flickr30k/flickr30k-images/',
        image_processor=image_processor,
        data_path='./data/GranDf/annotations/train/flickr_mergedGT_GCG_train.json',
        tokenizer=tokenizer,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        max_length=2048,
        pad_image_to_square=True,
        repeats=1,
        num_classes_per_sample=3,
        extra_image_processor=extra_image_processor)

    for i in range(1000):
        print(dataset[i])
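The caption tagging inside GCGDataset.create_conversation can be exercised on its own. The toy run below (caption and spans invented for illustration) shows how each grounded span from tokens_positive becomes "<p> phrase </p> [SEG]", applied right-to-left so earlier character offsets stay valid:

# --- illustrative sketch mirroring the nested helper in create_conversation ---
def tag_caption(caption, tokens):
    # Process spans from rightmost to leftmost so insertions do not
    # shift the offsets of spans that have not been processed yet.
    for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
        caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
    return caption

caption = "a dog chases a ball"
tokens_positive = [[2, 5], [15, 19]]  # character spans for "dog" and "ball"
print(tag_caption(caption, tokens_positive))
# a <p> dog </p> [SEG] chases a <p> ball </p> [SEG]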
    	
projects/glamm/datasets/refcoco_segm_dataset.py ADDED
@@ -0,0 +1,195 @@
| 1 | 
            +
            import copy
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
            import glob
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            import logging
         | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from mmengine import print_log
         | 
| 10 | 
            +
            from mmengine.config import Config, ConfigDict
         | 
| 11 | 
            +
            from PIL import Image
         | 
| 12 | 
            +
            from torch.utils.data import Dataset
         | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
            import torch.nn.functional as F
         | 
| 15 | 
            +
            from pycocotools.coco import COCO
         | 
| 16 | 
            +
            from pycocotools import mask as mask_utils
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from xtuner.registry import BUILDER
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from xtuner.dataset.utils import encode_fn
         | 
| 21 | 
            +
            from xtuner.dataset.map_fns import llava_map_fn
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            from projects.glamm.datasets.utils.utils import expand2square
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
         | 
| 26 | 
            +
            from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            from third_parts.mmdet.datasets.refcoco import RefCocoDataset
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class ReferSegmDataset(RefCocoDataset):
         | 
| 32 | 
            +
                def __init__(self,
         | 
| 33 | 
            +
                             data_root,
         | 
| 34 | 
            +
                             ann_file=None,
         | 
| 35 | 
            +
                             split_file=None,
         | 
| 36 | 
            +
                             image_processor=None,
         | 
| 37 | 
            +
                             extra_image_processor=None,
         | 
| 38 | 
            +
                             data_prefix=dict(img_path='train2014/'),
         | 
| 39 | 
            +
                             tokenizer=None,
         | 
| 40 | 
            +
                             template_map_fn=None,
         | 
| 41 | 
            +
                             max_length=2048,
         | 
| 42 | 
            +
                             pad_image_to_square=False,
         | 
| 43 | 
            +
                             num_classes_per_sample=3):
         | 
| 44 | 
            +
                    super().__init__(
         | 
| 45 | 
            +
                        data_root=data_root,
         | 
| 46 | 
            +
                        data_prefix=data_prefix,
         | 
| 47 | 
            +
                        pipeline=None,
         | 
| 48 | 
            +
                        ann_file=ann_file,
         | 
| 49 | 
            +
                        split_file=split_file,
         | 
| 50 | 
            +
                    )
         | 
| 51 | 
            +
                    self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    self.question_templates = SEG_QUESTIONS
         | 
| 54 | 
            +
                    if extra_image_processor is not None:
         | 
| 55 | 
            +
                        self.extra_image_processor = BUILDER.build(extra_image_processor)
         | 
| 56 | 
            +
                    self.num_classes_per_sample = num_classes_per_sample
         | 
| 57 | 
            +
                    self.tokenizer = BUILDER.build(tokenizer)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    self.tokenizer.add_tokens(
         | 
| 60 | 
            +
                        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
         | 
| 61 | 
            +
                    )
         | 
| 62 | 
            +
                    reg_tokens = ['<bbox>', '<point>']
         | 
| 63 | 
            +
                    segmentation_tokens = ['[SEG]']
         | 
| 64 | 
            +
                    phrase_tokens = ['<p>', '</p>']
         | 
| 65 | 
            +
                    special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
         | 
| 66 | 
            +
                    self.tokenizer.add_tokens(special_tokens, special_tokens=True)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    self.max_length = max_length
         | 
| 69 | 
            +
                    self.template_map_fn = BUILDER.build(template_map_fn)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    self.image_processor = BUILDER.build(image_processor)
         | 
| 72 | 
            +
                    size = self.image_processor.crop_size
         | 
| 73 | 
            +
                    if isinstance(size, dict):
         | 
| 74 | 
            +
                        self.image_w, self.image_h = size['width'], size['height']
         | 
| 75 | 
            +
                    self.pad_image_to_square = pad_image_to_square
         | 
| 76 | 
            +
             | 
| 77 | 
            +
    @property
    def modality_length(self):
        # Placeholder: every sample reports a constant length of 100;
        # measuring the true tokenized length would require decoding
        # every record up front.
        length_list = []
        for idx in range(len(self)):
            length_list.append(100)
        return length_list

    def _parse_annotations(self, ann_info):
        image_path = ann_info['img_path']
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            ann_info['g_pixel_values'] = g_pixel_values

        width, height = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        ann_info['pixel_values'] = image

        masks, phrases = [], []
        instances, text = ann_info['instances'], ann_info['text']
        index = np.random.choice(range(len(instances)), min(
            len(instances), self.num_classes_per_sample))
        for idx in index:
            inst = instances[idx]
            phrase = text[idx].lower()
            phrases.append(phrase)
            binary_mask = np.zeros((height, width), dtype=np.uint8)
            for seg in inst["mask"]:
                rles = mask_utils.frPyObjects([seg], height, width)
                m = mask_utils.decode(rles)
                m = m.astype(np.uint8)
                binary_mask += m.squeeze()
            masks.append(binary_mask)

        ann_info.update({
            'masks': masks,
            'phrases': phrases,
        })
        return ann_info

    def __getitem__(self, idx):
        data_dict = {}
        ann_info = super().__getitem__(idx)
        ann_info = self._parse_annotations(ann_info)

        data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values')
        data_dict['pixel_values'] = ann_info.pop('pixel_values')
        if len(ann_info['masks']) == 0:
            return self.__getitem__(0)
        data_dict['masks'] = torch.from_numpy(
            np.stack(ann_info['masks'], axis=0))

        conversation = []
        for phrase in ann_info['phrases']:
            question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
            conversation.append(
                {'input': question, 'output': random.choice(ANSWER_LIST)})

        data_dict['conversation'] = conversation
        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer,
                           max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        return data_dict

if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE
    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn

    dataset = ReferSegmDataset(
        tokenizer=tokenizer,
        image_processor=image_processor,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        extra_image_processor=extra_image_processor,
        data_root='data/coco/',
        data_prefix=dict(img_path='train2014/'),
        ann_file='refcoco+/instances.json',
        split_file='refcoco+/refs(unc).p',
    )
    for i in range(1000):
        dataset[i]
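
The mask-building loop in _parse_annotations above relies on pycocotools to
rasterize COCO-style polygon annotations into binary masks. A minimal
standalone sketch of just that step (the image size and polygon below are toy
values, not data from the repository):

import numpy as np
from pycocotools import mask as mask_utils

height, width = 4, 6                    # toy image size
polygon = [1, 1, 4, 1, 4, 3, 1, 3]      # x0, y0, x1, y1, ... in pixel coords

# frPyObjects turns the polygon into run-length encodings; decode then
# rasterizes them to an (H, W, 1) uint8 array, one channel per encoding.
rles = mask_utils.frPyObjects([polygon], height, width)
binary_mask = mask_utils.decode(rles).squeeze().astype(np.uint8)
print(binary_mask.shape)                # (4, 6); ones inside the polygon

Summing such per-polygon masks, as the dataset does, merges multi-part
instances into a single binary mask per phrase.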
    	
        projects/glamm/datasets/region_level_dataset.py
    ADDED
    
@@ -0,0 +1,297 @@
import copy
import random
import glob
import json
import logging
import os
import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER

from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square

from projects.glamm.datasets.utils.utils import ANSWER_LIST, REGION_QUESTIONS
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class RegionDataset(Dataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 repeats=1,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__()

        self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
        self.question_templates = REGION_QUESTIONS

        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.max_length = max_length
        self.template_map_fn = BUILDER.build(template_map_fn)

        self.text_data = self._load_annotations(data_path, image_folder)
        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        size = self.image_processor.crop_size

        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.repeats = repeats

    def _load_annotations(self, data_path, image_folder=None):
        self.coco = COCO(data_path)
        img_ids = self.coco.getImgIds()
        data_infos = []
        for img_id in img_ids:
            info = self.coco.loadImgs([img_id])[0]
            info['filename'] = info['file_name'].split('_')[-1]
            info['height'] = int(info['height'])
            info['width'] = int(info['width'])
            if min(info['height'], info['width']) < 32:
                continue
            data_infos.append(info)
        return data_infos

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.text_data:
            cur_len = 100
            length_list.append(cur_len)
        return length_list * self.repeats

    def __len__(self):
        return len(self.text_data) * self.repeats

    def real_len(self):
        return len(self.text_data)

    def region_processor(self, orig_size, post_size, bboxes, labels):
        orig_h, orig_w = orig_size
        post_h, post_w = post_size
        y_scale = post_h / orig_h
        x_scale = post_w / orig_w
        shuffle_ids = torch.randperm(len(labels))[:self.num_classes_per_sample]
        selected_bboxes = bboxes[shuffle_ids]

        # Ensure selected_bboxes is two-dimensional
        if len(selected_bboxes.shape) == 1:
            selected_bboxes = np.expand_dims(selected_bboxes, axis=0)

        selected_labels = [labels[i] for i in shuffle_ids]
        selected_bboxes[:, [0, 2]] *= x_scale
        selected_bboxes[:, [1, 3]] *= y_scale
        selected_bboxes = torch.tensor(
            selected_bboxes, dtype=torch.float32) / post_h
        return selected_bboxes, selected_labels

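    # Worked example for the scaling above (hypothetical numbers): a 640x480
    # image resized to a 336x336 crop gives x_scale = 336/640 = 0.525 and
    # y_scale = 336/480 = 0.7, so a box [64, 48, 320, 240] maps to
    # [33.6, 33.6, 168.0, 168.0]; dividing by post_h = 336 yields the
    # normalized box [0.1, 0.1, 0.5, 0.5]. Note the division by post_h alone
    # assumes a square post-processing size (post_h == post_w).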
    def _parse_annotations(self, img_info):
        data_dict = {}
        bboxes, captions = [], []
        ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        orig_w, orig_h = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        post_h, post_w = image.shape[1:3]
        data_dict['pixel_values'] = image

        for ann in ann_info:
            if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            bboxes.append([x1, y1, x1 + w, y1 + h])
            captions.append(img_info['caption'])

        if len(bboxes) == 0:
            # No usable boxes; signal the caller to fall back to another sample.
            return None

        bboxes = np.array(bboxes, dtype=np.float32)
        seg_map = img_info['file_name'].replace('jpg', 'png')
        bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)

        data_dict['bboxes'] = bboxes
        data_dict['captions'] = captions
        data_dict['seg_map'] = seg_map
        return data_dict

    def create_conversation(self, captions):
        questions = []
        answers = []
        for i, label in enumerate(captions):
            question = random.choice(self.question_templates).strip().replace('<region>', f'region{i + 1} <bbox>')
            questions.append(question)
            answers.append(label)

        conversation = []
        for i, (question, answer) in enumerate(zip(questions, answers)):
            if i == 0:
                question = self.begin_str + question
            conversation.append({'input': question, 'output': answer})
        return conversation

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = {}
        ann_info = copy.deepcopy(self.text_data[index])
        ann_info = self._parse_annotations(ann_info)
        if ann_info is None:
            # The image had no valid regions; retry with the first sample.
            return self.__getitem__(0)

        data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values', None)
        data_dict['pixel_values'] = ann_info.pop('pixel_values')
        data_dict['bboxes'] = ann_info.pop('bboxes', None)

        conversation = self.create_conversation(ann_info['captions'])
        data_dict['conversation'] = conversation

        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer,
                           max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        return data_dict

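# Illustrative create_conversation behavior (template text depends on
# REGION_QUESTIONS): with two captions, the first turn's input is begin_str
# plus a question whose '<region>' placeholder becomes 'region1 <bbox>',
# answered by the first caption; the second turn asks about 'region2 <bbox>'
# with no image prefix, answered by the second caption.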
class RefCocoGRegionDataset(RegionDataset):
    pass

class VisualGenomeRegionDataset(RegionDataset):
    def _parse_annotations(self, img_info):
        data_dict = {}
        bboxes, captions = [], []
        ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        orig_w, orig_h = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        post_h, post_w = image.shape[1:3]
        data_dict['pixel_values'] = image

        for ann in ann_info:
            if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            bboxes.append([x1, y1, x1 + w, y1 + h])
            captions.append(ann['caption'].strip())

        if len(bboxes) == 0:
            # No usable boxes; signal the caller to fall back to another sample.
            return None

        bboxes = np.array(bboxes, dtype=np.float32)
        seg_map = img_info['file_name'].replace('jpg', 'png')
        bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)

        data_dict['bboxes'] = bboxes
        data_dict['captions'] = captions
        data_dict['seg_map'] = seg_map
        return data_dict

if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE
    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
    dataset = VisualGenomeRegionDataset(
        image_folder='./data/visual_genome/images',
        image_processor=image_processor,
        data_path='data/visual_genome/train.json',
        tokenizer=tokenizer,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        max_length=2048,
        pad_image_to_square=False,
        repeats=1,
        num_classes_per_sample=3,
        extra_image_processor=None)

    for i in range(1000):
        print(dataset[i])
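
Both datasets above optionally pad images with expand2square before CLIP
preprocessing. The real helper lives in projects/glamm/datasets/utils/utils.py;
the sketch below is a common LLaVA-style implementation, shown only to
illustrate the call contract assumed here (PIL image in, centered square image
out, padded with the processor's mean color):

from PIL import Image

def expand2square_sketch(pil_img, background_color):
    # Pad the shorter side so the image becomes square, keeping it centered.
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        result.paste(pil_img, (0, (width - height) // 2))
    else:
        result.paste(pil_img, ((height - width) // 2, 0))
    return result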
    	
        projects/glamm/datasets/semantic_seg_dataset.py
    ADDED
    
@@ -0,0 +1,424 @@
import copy
import random
import glob
import json
import logging
import os
import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO

from xtuner.registry import BUILDER

from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square

from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class SemanticSegDataset(Dataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__()
        self.gcg_format = gcg_format
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        assert offline_processed_text_folder or (data_path and tokenizer)
        self.lazy = lazy

        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            self.image_label_datas = self.json_file_preprocess(data_path, image_folder)

        self.image_folder = image_folder

        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor

        size = self.image_processor.crop_size

        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    def json_file_preprocess(self, data_path, image_folder):
        # ade20k
        with open(data_path, 'r') as file:
            ade20k_classes = json.load(file)
        ade20k_image_dir = image_folder
        ade20k_images = [os.path.join(ade20k_image_dir, img) for img in os.listdir(ade20k_image_dir) if
                         img.endswith('.jpg')]
        ade20k_labels = [img.replace(".jpg", ".png").replace(
            "images", "annotations") for img in ade20k_images]
        self.classes = np.array(ade20k_classes)

        ret = []
        for image, label in zip(ade20k_images, ade20k_labels):
            ret.append({"image": image, "label": label})
        return ret

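    # Path convention assumed by json_file_preprocess above: each label map
    # shares its stem with the image but lives under "annotations" rather than
    # "images", e.g.
    #   .../ade20k/images/training/ADE_train_00000001.jpg
    #   -> .../ade20k/annotations/training/ADE_train_00000001.png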
    def __len__(self):
        return len(self.image_label_datas) * self.repeats

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.image_label_datas:
            length_list.append(100)
        length_list = length_list * self.repeats
        return length_list

    def real_len(self):
        return len(self.image_label_datas)

    def decode_mask(self, label_path):
        label = np.array(Image.open(label_path))

        # ade20k
        label = np.where(label == 0, 255, label - 1)
        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            return None, None

        selected_labels = np.random.choice(unique_labels, min(
            len(unique_labels), self.num_classes_per_sample), replace=False)
        label = torch.from_numpy(label).long()
        masks = torch.stack([label == class_id for class_id in selected_labels], dim=0)
        return masks, selected_labels

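    # Worked example of the remap inside decode_mask: ADE20K stores "unlabeled"
    # as 0 and classes as 1..150, so np.where(label == 0, 255, label - 1) sends
    # pixel values [0, 1, 150] to [255, 0, 149]: 0 becomes the 255 ignore index
    # and class ids become 0-based indices into self.classes.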
    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.image_label_datas[index])

        assert 'image' in data_dict.keys()
        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            image = Image.open(image_file).convert('RGB')
            if hasattr(self, 'extra_image_processor'):
                g_image = np.array(image)  # for grounding
                g_image = self.extra_image_processor.apply_image(g_image)
                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                data_dict['g_pixel_values'] = g_pixel_values

            ori_width, ori_height = image.size
            if self.pad_image_to_square:
                image = expand2square(image, tuple(int(x * 255)
                                      for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image

            # process and get masks
            data_dict['masks'], class_id = self.decode_mask(data_dict['label'])
            if class_id is None:
                return self.__getitem__(0)

            if self.gcg_format:
                # GCG-style conversation formatting is not implemented here.
                pass
            else:
                conversation = []
                for i, c_id in enumerate(class_id):
                    question = random.choice(SEG_QUESTIONS).format(
                        class_name=self.classes[c_id].lower())
                    if i == 0:
                        question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
                    conversation.append(
                        {'input': question, 'output': random.choice(ANSWER_LIST)})

            data_dict.update({'conversation': conversation})
        else:
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])
            data_dict['masks'] = None

        if self.lazy:
            result = self.template_map_fn(data_dict)
            data_dict.update(result)

            result = encode_fn(data_dict, tokenizer=self.tokenizer,
                               max_length=self.max_length, with_image_token=True)
            data_dict.update(result)

        return data_dict

class ADE20kSemanticSegDataset(SemanticSegDataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__(
            image_folder=image_folder,
            image_processor=image_processor,
            data_path=data_path,
            tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            max_length=max_length,
            pad_image_to_square=pad_image_to_square,
            num_proc=num_proc,
            lazy=lazy,
            repeats=repeats,
            gcg_format=gcg_format,
            num_classes_per_sample=num_classes_per_sample,
            extra_image_processor=extra_image_processor,
        )

class COCOStuffSemanticSegDataset(SemanticSegDataset):
| 249 | 
            +
                def __init__(self,
         | 
| 250 | 
            +
                             image_folder,
         | 
| 251 | 
            +
                             image_processor,
         | 
| 252 | 
            +
                             data_path=None,
         | 
| 253 | 
            +
                             tokenizer=None,
         | 
| 254 | 
            +
                             offline_processed_text_folder=None,
         | 
| 255 | 
            +
                             max_dataset_length=None,
         | 
| 256 | 
            +
                             dataset_map_fn=None,
         | 
| 257 | 
            +
                             template_map_fn=None,
         | 
| 258 | 
            +
                             max_length=2048,
         | 
| 259 | 
            +
                             pad_image_to_square=False,
         | 
| 260 | 
            +
                             num_proc=8,
         | 
| 261 | 
            +
                             lazy=False,
         | 
| 262 | 
            +
                             repeats=1,
         | 
| 263 | 
            +
                             label_path=None,
         | 
| 264 | 
            +
                             gcg_format=False,
         | 
| 265 | 
            +
                             num_classes_per_sample=3,
         | 
| 266 | 
            +
                             extra_image_processor=None):
         | 
| 267 | 
            +
                    self.label_path = label_path
         | 
| 268 | 
            +
                    super().__init__(
         | 
| 269 | 
            +
                        image_folder=image_folder,
         | 
| 270 | 
            +
                        image_processor=image_processor,
         | 
| 271 | 
            +
                        data_path=data_path,
         | 
| 272 | 
            +
                        tokenizer=tokenizer,
         | 
| 273 | 
            +
                        offline_processed_text_folder=offline_processed_text_folder,
         | 
| 274 | 
            +
                        max_dataset_length=max_dataset_length,
         | 
| 275 | 
            +
                        dataset_map_fn=dataset_map_fn,
         | 
| 276 | 
            +
                        template_map_fn=template_map_fn,
         | 
| 277 | 
            +
                        max_length=max_length,
         | 
| 278 | 
            +
                        pad_image_to_square=pad_image_to_square,
         | 
| 279 | 
            +
                        num_proc=num_proc,
         | 
| 280 | 
            +
                        lazy=lazy,
         | 
| 281 | 
            +
                        repeats=repeats,
         | 
| 282 | 
            +
                        gcg_format=gcg_format,
         | 
| 283 | 
            +
                        num_classes_per_sample=num_classes_per_sample,
         | 
| 284 | 
            +
                        extra_image_processor=extra_image_processor,
         | 
| 285 | 
            +
                    )
         | 
| 286 | 
            +
                    self.cocostuff_class2index = {c: i for i, c in enumerate(self.classes)}
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                def json_file_preprocess(self, data_path, image_folder):
         | 
| 289 | 
            +
                    # coco stuff
         | 
| 290 | 
            +
                    assert self.label_path is not None
         | 
| 291 | 
            +
                    with open(data_path, 'r') as file:
         | 
| 292 | 
            +
                        cocostuff_classes = [line.strip().split(": ")[-1]
         | 
| 293 | 
            +
                                             for line in file.readlines()[1:]]
         | 
| 294 | 
            +
                    coco_stuff_image_dir = image_folder
         | 
| 295 | 
            +
                    coco_stuff_label_dir = self.label_path
         | 
| 296 | 
            +
                    coco_stuff_labels = glob.glob(
         | 
| 297 | 
            +
                        os.path.join(coco_stuff_label_dir, "*.png"))
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    coco_stuff_images = [label.replace(".png", ".jpg").replace(coco_stuff_label_dir, coco_stuff_image_dir)
         | 
| 300 | 
            +
                                         for label in coco_stuff_labels]
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    self.classes = np.array(cocostuff_classes)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    ret = []
         | 
| 305 | 
            +
                    for image, label in zip(coco_stuff_images, coco_stuff_labels):
         | 
| 306 | 
            +
                        ret.append({"image": image, "label": label})
         | 
| 307 | 
            +
                    return ret
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                def decode_mask(self, label_path):
         | 
| 310 | 
            +
                    label = np.array(Image.open(label_path))
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                    # coco stuff
         | 
| 313 | 
            +
                    ignored_classes = [index for class_name,
         | 
| 314 | 
            +
                                       index in self.cocostuff_class2index.items() if "-" in class_name]
         | 
| 315 | 
            +
                    label = np.where(np.isin(label, ignored_classes), 255, label)
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
         | 
| 318 | 
            +
                    if not unique_labels:
         | 
| 319 | 
            +
                        print("No valid label !!!")
         | 
| 320 | 
            +
                        return None, None
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    # only choose 1
         | 
| 323 | 
            +
                    selected_labels = np.random.choice(unique_labels, min(
         | 
| 324 | 
            +
                        len(unique_labels), self.num_classes_per_sample), replace=False)
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    label = torch.from_numpy(label).long()
         | 
| 327 | 
            +
                    masks = torch.stack(
         | 
| 328 | 
            +
                        [label == class_id for class_id in selected_labels], dim=0)
         | 
| 329 | 
            +
                    return masks, selected_labels
         | 
| 330 | 
            +
             | 
| 331 | 
            +
            class PascalPartSemanticSegDataset(SemanticSegDataset):
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                def json_file_preprocess(self, data_path, image_folder):
         | 
| 334 | 
            +
                    self.coco_api = COCO(data_path)
         | 
| 335 | 
            +
                    img_ids = self.coco_api.getImgIds()
         | 
| 336 | 
            +
                    all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
         | 
| 337 | 
            +
                    class_map_pascal_part = {}
         | 
| 338 | 
            +
                    for cat in all_classes:
         | 
| 339 | 
            +
                        cat_main, cat_part = cat["name"].strip().split(":")
         | 
| 340 | 
            +
                        name = (cat_main, cat_part)
         | 
| 341 | 
            +
                        class_map_pascal_part[cat["id"]] = name
         | 
| 342 | 
            +
                    self.classes = class_map_pascal_part
         | 
| 343 | 
            +
                    return img_ids
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                def __getitem__(self, index):
         | 
| 346 | 
            +
                    index = index % self.real_len()
         | 
| 347 | 
            +
                    img_id = self.image_label_datas[index]
         | 
| 348 | 
            +
                    img_info = self.coco_api.loadImgs([img_id])[0]
         | 
| 349 | 
            +
                    file_name = img_info["file_name"]
         | 
| 350 | 
            +
                    data_dict = {}
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    image_file = os.path.join(self.image_folder, file_name)
         | 
| 353 | 
            +
                    image = Image.open(image_file).convert('RGB')
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    if hasattr(self, 'extra_image_processor'):
         | 
| 356 | 
            +
                        g_image = np.array(image)  # for grounding
         | 
| 357 | 
            +
                        g_image = self.extra_image_processor.apply_image(g_image)
         | 
| 358 | 
            +
                        g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
         | 
| 359 | 
            +
                        data_dict['g_pixel_values'] = g_pixel_values
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                    if self.pad_image_to_square:
         | 
| 362 | 
            +
                        image = expand2square(
         | 
| 363 | 
            +
                            image,  tuple(int(x * 255) for x in self.image_processor.image_mean))
         | 
| 364 | 
            +
                    image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
         | 
| 365 | 
            +
                    data_dict['pixel_values'] = image
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"])
         | 
| 368 | 
            +
                    annotations = self.coco_api.loadAnns(annotation_ids)
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                    if not annotations:
         | 
| 371 | 
            +
                        return self.__getitem__(0)
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                    sampled_anns = np.random.choice(annotations, min(
         | 
| 374 | 
            +
                        len(annotations), self.num_classes_per_sample), replace=False)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    conversation = []
         | 
| 377 | 
            +
                    for i, ann in enumerate(sampled_anns):
         | 
| 378 | 
            +
                        cat_id = ann['category_id']
         | 
| 379 | 
            +
                        sampled_cls = self.classes[cat_id]
         | 
| 380 | 
            +
                        if isinstance(sampled_cls, tuple):
         | 
| 381 | 
            +
                            obj, part = sampled_cls
         | 
| 382 | 
            +
                            name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
         | 
| 383 | 
            +
                        else:
         | 
| 384 | 
            +
                            name = sampled_cls
         | 
| 385 | 
            +
                        question = random.choice(SEG_QUESTIONS).format(class_name=name)
         | 
| 386 | 
            +
                        if i == 0:
         | 
| 387 | 
            +
                            question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
         | 
| 388 | 
            +
                        conversation.append(
         | 
| 389 | 
            +
                            {'input': question, 'output': random.choice(ANSWER_LIST)})
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    masks = [self.coco_api.annToMask(ann) for ann in sampled_anns]
         | 
| 392 | 
            +
                    masks = np.stack(masks, axis=0)
         | 
| 393 | 
            +
                    masks = torch.from_numpy(masks)
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                    data_dict['masks'] = masks
         | 
| 396 | 
            +
                    data_dict['conversation'] = conversation
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                    if self.lazy:
         | 
| 399 | 
            +
                        result = self.template_map_fn(data_dict)
         | 
| 400 | 
            +
                        data_dict.update(result)
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                        result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
         | 
| 403 | 
            +
                        data_dict.update(result)
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    return data_dict
         | 
| 406 | 
            +
             | 
| 407 | 
            +
            class PacoSemanticSegDataset(PascalPartSemanticSegDataset):
         | 
| 408 | 
            +
                def json_file_preprocess(self, data_path, image_folder):
         | 
| 409 | 
            +
                    self.coco_api = COCO(data_path)
         | 
| 410 | 
            +
                    all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
         | 
| 411 | 
            +
                    class_map_paco = {}
         | 
| 412 | 
            +
                    for cat in all_classes:
         | 
| 413 | 
            +
                        cat_split = cat["name"].strip().split(":")
         | 
| 414 | 
            +
                        if len(cat_split) == 1:
         | 
| 415 | 
            +
                            name = cat_split[0].split("_(")[0]
         | 
| 416 | 
            +
                        else:
         | 
| 417 | 
            +
                            assert len(cat_split) == 2
         | 
| 418 | 
            +
                            obj, part = cat_split
         | 
| 419 | 
            +
                            obj = obj.split("_(")[0]
         | 
| 420 | 
            +
                            part = part.split("_(")[0]
         | 
| 421 | 
            +
                            name = (obj, part)
         | 
| 422 | 
            +
                        class_map_paco[cat["id"]] = name
         | 
| 423 | 
            +
                    self.classes = class_map_paco
         | 
| 424 | 
            +
                    return self.coco_api.getImgIds()
         | 
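The decode_mask method above is the heart of the COCO-Stuff pipeline: it turns one label PNG into a stack of per-class boolean masks. A minimal self-contained sketch of the same logic, using a synthetic 4x4 label array in place of a real label file (the array values are illustrative, not from the dataset):

import numpy as np
import torch

# Synthetic stand-in for np.array(Image.open(label_path));
# 255 is the ignore index, exactly as in decode_mask above.
label = np.array([[0, 0, 1, 255],
                  [0, 2, 1, 255],
                  [2, 2, 1, 1],
                  [0, 0, 0, 1]])

unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
num_classes_per_sample = 3  # mirrors the dataset default
selected = np.random.choice(unique_labels,
                            min(len(unique_labels), num_classes_per_sample),
                            replace=False)

label_t = torch.from_numpy(label).long()
masks = torch.stack([label_t == int(c) for c in selected], dim=0)
print(masks.shape)  # torch.Size([3, 4, 4]): one boolean mask per sampled class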
    	
projects/glamm/datasets/utils/ade20k_classes.json
ADDED

@@ -0,0 +1,30 @@
[
    "wall", "building", "sky", "floor", "tree", "ceiling", "road",
    "bed", "windowpane", "grass", "cabinet", "sidewalk",
    "person", "earth", "door", "table", "mountain", "plant",
    "curtain", "chair", "car", "water", "painting", "sofa",
    "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
    "seat", "fence", "desk", "rock", "wardrobe", "lamp",
    "bathtub", "railing", "cushion", "base", "box", "column",
    "signboard", "chest of drawers", "counter", "sand", "sink",
    "skyscraper", "fireplace", "refrigerator", "grandstand",
    "path", "stairs", "runway", "case", "pool table", "pillow",
    "screen door", "stairway", "river", "bridge", "bookcase",
    "blind", "coffee table", "toilet", "flower", "book", "hill",
    "bench", "countertop", "stove", "palm", "kitchen island",
    "computer", "swivel chair", "boat", "bar", "arcade machine",
    "hovel", "bus", "towel", "light", "truck", "tower",
    "chandelier", "awning", "streetlight", "booth",
    "television receiver", "airplane", "dirt track", "apparel",
    "pole", "land", "bannister", "escalator", "ottoman", "bottle",
    "buffet", "poster", "stage", "van", "ship", "fountain",
    "conveyer belt", "canopy", "washer", "plaything",
    "swimming pool", "stool", "barrel", "basket", "waterfall",
    "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
    "step", "tank", "trade name", "microwave", "pot", "animal",
    "bicycle", "lake", "dishwasher", "screen", "blanket",
    "sculpture", "hood", "sconce", "vase", "traffic light",
    "tray", "ashcan", "fan", "pier", "crt screen", "plate",
    "monitor", "bulletin board", "shower", "radiator", "glass",
    "clock", "flag"
]
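This JSON file is the 150-name ADE20K class vocabulary consumed by the dataset code above. A minimal sketch of loading it (assuming the repo-relative path):

import json

# Repo-relative path to the file shown above.
with open('projects/glamm/datasets/utils/ade20k_classes.json', 'r') as f:
    ade20k_classes = json.load(f)

print(len(ade20k_classes))  # 150 class names
print(ade20k_classes[:3])   # ['wall', 'building', 'sky']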
    	
projects/glamm/datasets/utils/cocostuff_classes.txt
ADDED

@@ -0,0 +1,183 @@
0: unlabeled
1: person
2: bicycle
3: car
4: motorcycle
5: airplane
6: bus
7: train
8: truck
9: boat
10: traffic light
11: fire hydrant
12: street sign
13: stop sign
14: parking meter
15: bench
16: bird
17: cat
18: dog
19: horse
20: sheep
21: cow
22: elephant
23: bear
24: zebra
25: giraffe
26: hat
27: backpack
28: umbrella
29: shoe
30: eye glasses
31: handbag
32: tie
33: suitcase
34: frisbee
35: skis
36: snowboard
37: sports ball
38: kite
39: baseball bat
40: baseball glove
41: skateboard
42: surfboard
43: tennis racket
44: bottle
45: plate
46: wine glass
47: cup
48: fork
49: knife
50: spoon
51: bowl
52: banana
53: apple
54: sandwich
55: orange
56: broccoli
57: carrot
58: hot dog
59: pizza
60: donut
61: cake
62: chair
63: couch
64: potted plant
65: bed
66: mirror
67: dining table
68: window
69: desk
70: toilet
71: door
72: tv
73: laptop
74: mouse
75: remote
76: keyboard
77: cell phone
78: microwave
79: oven
80: toaster
81: sink
82: refrigerator
83: blender
84: book
85: clock
86: vase
87: scissors
88: teddy bear
89: hair drier
90: toothbrush
91: hair brush
92: banner
93: blanket
94: branch
95: bridge
96: building-other
97: bush
98: cabinet
99: cage
100: cardboard
101: carpet
102: ceiling-other
103: ceiling-tile
104: cloth
105: clothes
106: clouds
107: counter
108: cupboard
109: curtain
110: desk-stuff
111: dirt
112: door-stuff
113: fence
114: floor-marble
115: floor-other
116: floor-stone
117: floor-tile
118: floor-wood
119: flower
120: fog
121: food-other
122: fruit
123: furniture-other
124: grass
125: gravel
126: ground-other
127: hill
128: house
129: leaves
130: light
131: mat
132: metal
133: mirror-stuff
134: moss
135: mountain
136: mud
137: napkin
138: net
139: paper
140: pavement
141: pillow
142: plant-other
143: plastic
144: platform
145: playingfield
146: railing
147: railroad
148: river
149: road
150: rock
151: roof
152: rug
153: salad
154: sand
155: sea
156: shelf
157: sky
158: skyscraper
159: snow
160: solid-other
161: stairs
162: stone
163: straw
164: structural-other
165: table
166: tent
167: textile-other
168: towel
169: tree
170: vegetable
171: wall-brick
172: wall-concrete
173: wall-other
174: wall-panel
175: wall-stone
176: wall-tile
177: wall-wood
178: water-other
179: waterdrops
180: window-blind
181: window-other
182: wood
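COCOStuffSemanticSegDataset.json_file_preprocess parses this file by skipping the first line ("0: unlabeled") and keeping the text after the colon; decode_mask then treats every name containing '-' as an ignore class. A standalone sketch of that parsing (assuming the repo-relative path):

# Repo-relative path to the file shown above.
with open('projects/glamm/datasets/utils/cocostuff_classes.txt', 'r') as f:
    # Skip "0: unlabeled"; keep the name after ": " for every other line.
    cocostuff_classes = [line.strip().split(": ")[-1] for line in f.readlines()[1:]]

cocostuff_class2index = {c: i for i, c in enumerate(cocostuff_classes)}
print(cocostuff_classes[0])             # 'person'
print(cocostuff_class2index['person'])  # 0

# Names containing '-' are merged stuff/other classes, remapped to 255 in decode_mask:
print([c for c in cocostuff_classes if '-' in c][:3])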
    	
projects/glamm/datasets/utils/utils.py
ADDED

@@ -0,0 +1,131 @@
from PIL import Image


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


CAPTION_QUESTIONS = [
    'Could you please give me a detailed description of the image?',
    'Can you provide a thorough description of this image?',
    'Please provide a thorough description of this image',
    'Please provide a thorough description of this image.',
    'Please describe in detail the contents of the image.',
    'Please describe in detail the contents of the image',
    'Could you give a comprehensive explanation of what can be found within this picture?',
    'Could you give me an elaborate explanation of this picture?',
    'Could you provide me with a detailed analysis of this photo?',
    'Could you please give me a detailed description of the image?',
    'Can you provide a thorough description of this image?',
    'Please describe in detail the contents of the image',
    'Please describe in detail the contents of the image.',
    'Can you give a comprehensive explanation of this photo',
    'Please provide an elaborate explanation of this picture.',
    'Please provide an elaborate explanation of this picture',
    'Could you provide me with a detailed analysis of this photo',
]

REGION_QUESTIONS = [
    'Can you provide me with a detailed description of the region in the picture marked by <region>?',
    "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
    'What can you tell me about the region indicated by <region> in the image?',
    "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
    'Could you describe the region shown as <region> in the picture in great detail?',
    'What details can you give me about the region outlined by <region> in the photo?',
    'Please provide me with a comprehensive description of the region marked with <region> in the image.',
    'Can you give me a detailed account of the region labeled as <region> in the picture?',
    "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
    'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
    'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
    "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
    'What can you tell me about the region indicated by <region> in the image, exactly?',
    "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
    'Could you describe the region shown as <region> in the picture in great detail, please?',
    'What details can you give me about the region outlined by <region> in the photo, please?',
    'Please provide me with a comprehensive description of the region marked with <region> in the image.',
    'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
    "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
    'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
]

REGION_GROUP_QUESTIONS = [
    'Could you please give me a detailed description of these areas <region>?',
    'Can you provide a thorough description of the regions <region> in this image?',
    'Please describe in detail the contents of the boxed areas <region>.',
    'Could you give a comprehensive explanation of what can be found within <region> in the picture?',
    'Could you give me an elaborate explanation of the <region> regions in this picture?',
    'Can you provide a comprehensive description of the areas identified by <region> in this photo?',
    'Help me understand the specific locations labeled <region> in this picture in detail, please.',
    'What is the detailed information about the areas marked by <region> in this image?',
    'Could you provide me with a detailed analysis of the regions designated <region> in this photo?',
    'What are the specific features of the areas marked <region> in this picture that you can describe in detail?',
    'Could you elaborate on the regions identified by <region> in this image?',
    'What can you tell me about the areas labeled <region> in this picture?',
    'Can you provide a thorough analysis of the specific locations designated <region> in this photo?',
    'I am interested in learning more about the regions marked <region> in this image. Can you provide me with more information?',
    'Could you please provide a detailed description of the areas identified by <region> in this photo?',
    'What is the significance of the regions labeled <region> in this picture?',
    'I would like to know more about the specific locations designated <region> in this image. Can you provide me with more information?',
    'Can you provide a detailed breakdown of the regions marked <region> in this photo?',
    'What specific features can you tell me about the areas identified by <region> in this picture?',
    'Could you please provide a comprehensive explanation of the locations labeled <region> in this image?',
    'Can you provide a detailed account of the regions designated <region> in this photo?',
    'I am curious about the areas marked <region> in this picture. Can you provide me with a detailed analysis?',
    'What important details can you tell me about the specific locations identified by <region> in this image?',
    'Could you please provide a detailed description of the regions labeled <region> in this photo?',
    'What can you tell me about the features of the areas designated <region> in this picture?',
    'Can you provide a comprehensive overview of the regions marked <region> in this image?',
    'I would like to know more about the specific locations identified by <region> in this photo. Can you provide me with more information?',
    'What is the detailed information you have on the areas labeled <region> in this picture?',
    'Could you provide me with a thorough analysis of the regions designated <region> in this image?',
    'Can you provide a detailed explanation of the specific locations marked by <region> in this photo?'
]

GCG_QUESTIONS = [
    'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    'Can you provide a thorough description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
    'Please describe in detail the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
    'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
    'Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]

SEG_QUESTIONS = [
    "Can you segment the {class_name} in this image?",
    "Please segment {class_name} in this image.",
    "What is {class_name} in this image? Please respond with segmentation mask.",
    "What is {class_name} in this image? Please output segmentation mask.",

    "Can you segment the {class_name} in this image",
    "Please segment {class_name} in this image",
    "What is {class_name} in this image? Please respond with segmentation mask",
    "What is {class_name} in this image? Please output segmentation mask",

    "Could you provide a segmentation mask for the {class_name} in this image?",
    "Please identify and segment the {class_name} in this image.",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
    "Can you highlight the {class_name} in this image with a segmentation mask?",

    "Could you provide a segmentation mask for the {class_name} in this image",
    "Please identify and segment the {class_name} in this image",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask",
    "Can you highlight the {class_name} in this image with a segmentation mask",
]

ANSWER_LIST = [
    "It is [SEG].",
    "Sure, [SEG].",
    "Sure, it is [SEG].",
    "Sure, the segmentation result is [SEG].",
    "[SEG].",
]
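Putting this module together: expand2square pads an image to a square canvas before CLIP-style preprocessing, and the datasets above draw one question template plus one [SEG] answer per sampled class. A minimal sketch, assuming the repo root is on PYTHONPATH; the class name, image size, and background color are illustrative:

import random
from PIL import Image

# Imports the definitions from the file shown above (repo root on PYTHONPATH).
from projects.glamm.datasets.utils.utils import (expand2square, SEG_QUESTIONS,
                                                 ANSWER_LIST)

# expand2square pads the shorter side, centering the original image.
img = Image.new('RGB', (640, 480))  # illustrative size
padded = expand2square(img, (122, 116, 104))  # roughly CLIP image_mean * 255
print(padded.size)  # (640, 640)

# One segmentation-style conversation turn, as the datasets assemble it:
question = random.choice(SEG_QUESTIONS).format(class_name='dog')  # 'dog' is illustrative
answer = random.choice(ANSWER_LIST)  # answers carry the [SEG] placeholder token
print({'input': question, 'output': answer})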
    	
        projects/glamm/models/glamm.py
    ADDED
    
    | @@ -0,0 +1,183 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from xtuner.registry import BUILDER
         | 
| 5 | 
            +
            from xtuner.model.utils import LoadWoInit, guess_load_checkpoint
         | 
| 6 | 
            +
            from xtuner.model.llava import LLaVAModel
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from mmengine.model import BaseModel
         | 
| 9 | 
            +
            from mmengine import print_log
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from projects.glamm.utils import prepare_inputs_labels_for_multimodal
         | 
| 12 | 
            +
            from projects.glamm.utils import DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class GLaMM(LLaVAModel):
         | 
| 16 | 
            +
                def __init__(self,
         | 
| 17 | 
            +
                             use_activation_checkpointing=True,
         | 
| 18 | 
            +
                             tokenizer=None,
         | 
| 19 | 
            +
                             grounding_encoder=None,
         | 
| 20 | 
            +
                             region_encoder=None,
         | 
| 21 | 
            +
                             loss_mask=None,
                 loss_dice=None,
                 *args, **kwargs):
        super(GLaMM, self).__init__(
            *args, use_activation_checkpointing=use_activation_checkpointing, **kwargs)

        self.use_activation_checkpointing = use_activation_checkpointing
        self.tokenizer = BUILDER.build(tokenizer)
        self._add_special_tokens()

        self.grounding_encoder = BUILDER.build(grounding_encoder)
        self.grounding_encoder.requires_grad_(False)
        self.grounding_encoder.mask_decoder.requires_grad_(True)

        if region_encoder is not None:
            self.region_encoder = BUILDER.build(region_encoder)

        in_dim = self.config.hidden_size
        out_dim = self.grounding_encoder.mask_decoder.transformer_dim
        self.text_hidden_fcs = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)

    def _add_special_tokens(self):
        reg_tokens = ['<im_start>', '<im_end>', '<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        num_new_tokens = self.tokenizer.add_tokens(
            special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            self.llm.resize_token_embeddings(len(self.tokenizer))
            input_embeddings = self.llm.get_input_embeddings().weight.data
            output_embeddings = self.llm.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

        self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0]
        self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0]
        self.bbox_token_idx = self.tokenizer("<bbox>", add_special_tokens=False).input_ids[0]

        if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
            self.llm.enable_input_require_grads()

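`_add_special_tokens` above initializes every newly added row of the embedding matrices to the mean of the pre-existing rows, which keeps the new tokens close to the existing embedding distribution. A minimal standalone sketch of that trick, not part of the migrated file, using a toy nn.Embedding in place of the LLM (vocabulary size and dimensions are illustrative):

import torch
import torch.nn as nn

# Toy stand-in: a 10-token vocabulary with 8-dim embeddings, 3 new tokens.
vocab_size, dim, num_new_tokens = 10, 8, 3
embed = nn.Embedding(vocab_size + num_new_tokens, dim)

with torch.no_grad():
    weights = embed.weight.data
    # Mean of the original rows becomes the starting point for each new row.
    avg = weights[:-num_new_tokens].mean(dim=0, keepdim=True)
    weights[-num_new_tokens:] = avg

assert torch.allclose(embed.weight.data[-1], avg.squeeze(0))
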
    def forward(self, data, data_samples=None, mode='loss'):
        if 'pixel_values' in data:
            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
            data['pixel_values'] = pixel_values
            bboxes = data.pop('bboxes', None)
            if bboxes is not None:
                select_hidden_state_layer = -2
                num_level_reg_features = 4
                mlvl_reg_features = visual_outputs.hidden_states[select_hidden_state_layer::-3]
                mlvl_reg_features = mlvl_reg_features[::-1]
                mlvl_reg_features = mlvl_reg_features[-num_level_reg_features:]
                mlvl_reg_features = [item[:, 1:] for item in mlvl_reg_features]
                mlvl_reg_features = self.region_encoder(mlvl_reg_features, bboxes)
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

            if bboxes is not None:
                inputs_embeds = data['inputs_embeds']
                for i, reg_feat in enumerate(mlvl_reg_features):
                    reg_mask = data['new_input_ids'][i] == self.bbox_token_idx
                    inputs_embeds[i][reg_mask] = reg_feat
                data['inputs_embeds'] = inputs_embeds

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

         | 
| 112 | 
            +
                    g_pixel_values = data.pop('g_pixel_values', None)
         | 
| 113 | 
            +
                    gt_masks = data.pop('masks', None)
         | 
| 114 | 
            +
                    new_input_ids = data.pop('new_input_ids', None)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    output = self.llm(output_hidden_states=True, **data)
         | 
| 117 | 
            +
                    if gt_masks is None:
         | 
| 118 | 
            +
                        return {'llm_loss': output.loss}
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                    resize_list = [pixel.shape[-2:] for pixel in g_pixel_values]
         | 
| 121 | 
            +
                    ori_size_list = [mask.shape[-2:] for mask in gt_masks]
         | 
| 122 | 
            +
                    g_pixel_values = torch.stack([
         | 
| 123 | 
            +
                        self.grounding_encoder.preprocess(pixel) for pixel in g_pixel_values
         | 
| 124 | 
            +
                    ])
         | 
| 125 | 
            +
                    image_embeddings = self.grounding_encoder.image_encoder(g_pixel_values)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    seg_token_mask = new_input_ids == self.seg_token_idx
         | 
| 128 | 
            +
                    hidden_states = output.hidden_states
         | 
| 129 | 
            +
                    hidden_states = self.text_hidden_fcs(hidden_states[-1])
         | 
| 130 | 
            +
                    pred_embeddings = hidden_states[seg_token_mask]
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    seg_token_counts = seg_token_mask.int().sum(-1)
         | 
| 133 | 
            +
                    pred_embeddings_list = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0)
         | 
| 134 | 
            +
                    
         | 
| 135 | 
            +
                    pred_masks = self._generate_and_postprocess_masks(
         | 
| 136 | 
            +
                        pred_embeddings_list, image_embeddings, resize_list, ori_size_list)
         | 
| 137 | 
            +
                    
         | 
| 138 | 
            +
                    bs = len(pred_masks)
         | 
| 139 | 
            +
                    loss_mask, loss_dice = 0, 0
         | 
| 140 | 
            +
                    for i in range(bs):
         | 
| 141 | 
            +
                        pred_mask = pred_masks[i]
         | 
| 142 | 
            +
                        gt_mask = gt_masks[i]
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                        sam_loss_mask = self.loss_mask(pred_mask, gt_mask)
         | 
| 145 | 
            +
                        sam_loss_dice = self.loss_dice(pred_mask, gt_mask)
         | 
| 146 | 
            +
                        accuracy = torch.eq((pred_mask.sigmoid() > 0.5), gt_mask).to(pred_mask).mean()
         | 
| 147 | 
            +
                        loss_mask += sam_loss_mask
         | 
| 148 | 
            +
                        loss_dice += sam_loss_dice
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
                    loss_dict = {
         | 
| 152 | 
            +
                        'loss_mask': loss_mask / bs,
         | 
| 153 | 
            +
                        'loss_dice': loss_dice / bs,
         | 
| 154 | 
            +
                        'accuracy': accuracy,
         | 
| 155 | 
            +
                        'llm_loss': output.loss,
         | 
| 156 | 
            +
                    }
         | 
| 157 | 
            +
                    return loss_dict
         | 
| 158 | 
            +
             | 
| 159 | 
            +
              
         | 
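Gathering one [SEG] embedding per predicted mask above is a flatten-then-split pattern: the boolean mask pulls all [SEG] hidden states out of the batch in row-major order, and torch.split with per-sample counts hands them back to their images. A toy sketch, separate from the migrated file (token id 7 is illustrative):

import torch

hidden = torch.arange(24.).reshape(2, 4, 3)        # (batch=2, seq=4, dim=3)
input_ids = torch.tensor([[1, 7, 2, 7],
                          [7, 3, 4, 5]])           # 7 plays the role of seg_token_idx

seg_mask = input_ids == 7
flat = hidden[seg_mask]                            # (3, 3): all [SEG] states, batch-major
counts = seg_mask.int().sum(-1).tolist()           # [2, 1]
per_image = torch.split(flat, counts, dim=0)

assert per_image[0].shape == (2, 3) and per_image[1].shape == (1, 3)
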
    def _generate_and_postprocess_masks(self, pred_embeddings, image_embeddings, resize_list=None, orig_size_list=None, infer=False):
        pred_masks = []
        for i, pred_embedding in enumerate(pred_embeddings):
            sparse_embeddings, dense_embeddings = self.grounding_encoder.prompt_encoder(
                points=None, boxes=None, masks=None, text_embeds=pred_embedding.unsqueeze(1)
            )
            sparse_embeddings = sparse_embeddings.to(pred_embedding.dtype)
            low_res_masks, _ = self.grounding_encoder.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.grounding_encoder.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
                multimask_output=False)

            pred_mask = self.grounding_encoder.postprocess_masks(
                low_res_masks, input_size=resize_list[i], original_size=orig_size_list[i])
            pred_masks.append(pred_mask[:, 0])
        return pred_masks

    def predict(self, data, data_samples=None):
        pass

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs
    	
projects/glamm/models/region_encoder.py
ADDED

@@ -0,0 +1,359 @@
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from torch import Tensor

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv import ops
from mmcv.cnn import ConvModule, Linear
from mmengine.model import BaseModule

class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
    """Base class for RoI extractor.

    Args:
        roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and
            arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (list[int]): Strides of input feature maps.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 roi_layer,
                 out_channels: int,
                 featmap_strides: List[int],
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides

    @property
    def num_inputs(self) -> int:
        """int: Number of input feature maps."""
        return len(self.featmap_strides)

    def build_roi_layers(self, layer_cfg,
                         featmap_strides: List[int]) -> nn.ModuleList:
        """Build the RoI operator that extracts features from each level's feature map.

        Args:
            layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
                configure the RoI layer operation. Options are modules under
                ``mmcv/ops`` such as ``RoIAlign``.
            featmap_strides (list[int]): The stride of each input feature map
                w.r.t. the original image size, used to scale RoI coordinates
                from the original image coordinate system to the feature
                coordinate system.

        Returns:
            :obj:`nn.ModuleList`: The RoI extractor modules for each level
                feature map.
        """

        cfg = layer_cfg.copy()
        layer_type = cfg.pop('type')
        if isinstance(layer_type, str):
            assert hasattr(ops, layer_type)
            layer_cls = getattr(ops, layer_type)
        else:
            layer_cls = layer_type
        roi_layers = nn.ModuleList(
            [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
        return roi_layers

    def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor:
        """Scale RoI coordinates by scale factor.

        Args:
            rois (Tensor): RoI (Region of Interest), shape (n, 5)
            scale_factor (float): Scale factor that RoI will be multiplied by.

        Returns:
            Tensor: Scaled RoI.
        """

        cx = (rois[:, 1] + rois[:, 3]) * 0.5
        cy = (rois[:, 2] + rois[:, 4]) * 0.5
        w = rois[:, 3] - rois[:, 1]
        h = rois[:, 4] - rois[:, 2]
        new_w = w * scale_factor
        new_h = h * scale_factor
        x1 = cx - new_w * 0.5
        x2 = cx + new_w * 0.5
        y1 = cy - new_h * 0.5
        y2 = cy + new_h * 0.5
        new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
        return new_rois

    @abstractmethod
    def forward(self,
                feats: Tuple[Tensor],
                rois: Tensor,
                roi_scale_factor: Optional[float] = None) -> Tensor:
        """Extract RoI features.

        Args:
            feats (Tuple[Tensor]): Multi-scale features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates the batch id of each RoI.
            roi_scale_factor (Optional[float]): RoI scale factor.
                Defaults to None.

        Returns:
            Tensor: RoI feature.
        """
        pass

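roi_rescale grows or shrinks each box about its own center and leaves the batch-index column untouched. A quick worked check of that arithmetic, separate from the migrated file (box values chosen for illustration):

import torch

# One RoI in (batch_idx, x1, y1, x2, y2) format: a 10x10 box centered at (15, 15).
rois = torch.tensor([[0., 10., 10., 20., 20.]])

scale = 2.0
cx = (rois[:, 1] + rois[:, 3]) * 0.5   # 15
cy = (rois[:, 2] + rois[:, 4]) * 0.5   # 15
w = (rois[:, 3] - rois[:, 1]) * scale  # 20
h = (rois[:, 4] - rois[:, 2]) * scale  # 20
scaled = torch.stack((rois[:, 0], cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=-1)

# Same center, doubled side length, batch index preserved.
assert torch.equal(scaled, torch.tensor([[0., 5., 5., 25., 25.]]))
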

class MLVLFuseModule(nn.Module):
    def __init__(self, input_dims=1024, embed_dims=1024, num_levels=3, num_fuse=4):
        super(MLVLFuseModule, self).__init__()
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_fuse = num_fuse
        self.input_dims = input_dims
        self.shuffle_channels = embed_dims // 4

        # contains the tuples of level indices that will do the interaction
        self.fuse_lvl_list = []
        num_levels = self.num_levels
        for lvl in range(num_levels):
            top_lvl = min(lvl + 1, num_levels - 1)
            dow_lvl = max(lvl - 1, 0)
            tar_lvl = lvl
            self.fuse_lvl_list.append((tar_lvl, top_lvl, dow_lvl))

        self.remain_chs = self.embed_dims - self.shuffle_channels * 2
        self._init_layers()

    def generate_coordinate(self, featmap_sizes, device='cuda'):
        x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
        y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([featmap_sizes[0], 1, -1, -1])
        x = x.expand([featmap_sizes[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)

        return coord_feat

    def _init_layers(self):
        self.input_conv = nn.ModuleList([nn.Conv2d(self.input_dims + 2,
                                                   self.embed_dims, 1)
                                         for _ in range(self.num_levels)])
        self.fuse_convs = nn.ModuleList()
        for i in range(self.num_fuse):
            self.fuse_convs.append(
                ConvModule(self.embed_dims,
                           self.embed_dims,
                           3,
                           stride=1,
                           padding=3 // 2,
                           conv_cfg=None,
                           norm_cfg=dict(type='GN',
                                         num_groups=64,
                                         requires_grad=True)
                           ))

    def init_weights(self):
        pass

    def _single_shuffle(self, inputs, conv_module):
        if not isinstance(conv_module, (nn.ModuleList, list)):
            conv_module = [conv_module]
        for single_conv_m in conv_module:
            fused_inputs = []
            for fuse_lvl_tuple in self.fuse_lvl_list:
                tar_lvl, top_lvl, dow_lvl = fuse_lvl_tuple
                tar_input = inputs[tar_lvl]
                top_input = inputs[top_lvl]
                down_input = inputs[dow_lvl]
                remain = tar_input[:, :self.remain_chs]
                from_top = top_input[:, self.remain_chs:][:, self.shuffle_channels:]
                from_top = F.interpolate(from_top.to(torch.float32),
                                         size=tar_input.shape[-2:],
                                         mode='bilinear',
                                         align_corners=True)
                from_down = down_input[:, self.remain_chs:][:, :self.shuffle_channels]
                from_down = F.interpolate(from_down.to(torch.float32),
                                          size=tar_input.shape[-2:],
                                          mode='bilinear',
                                          align_corners=True)
                fused_inputs.append(
                    torch.cat([remain, from_top.to(remain.dtype), from_down.to(remain.dtype)], dim=1))
            fused_inputs = [single_conv_m(item) for item in fused_inputs]
            inputs = fused_inputs
        return inputs

    def forward(self, inputs):
        feat_size = [item.shape for item in inputs]
        new_inputs = []
        for feat, single_feat_size in zip(inputs, feat_size):
            coord_feat = self.generate_coordinate(
                single_feat_size, device=inputs[0].device)
            feat = torch.cat([feat, coord_feat.to(feat.dtype)], dim=1)
            new_inputs.append(feat)
        inputs = new_inputs

        inputs = [self.input_conv[lvl](item)
                  for lvl, item in enumerate(inputs)]

        for conv_m in self.fuse_convs:
            inputs = self._single_shuffle(inputs, [conv_m])
        return inputs

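generate_coordinate implements the CoordConv-style trick of appending normalized (x, y) grids as two extra channels, so the 1x1 input_conv sees position as well as content (which is why input_conv takes input_dims + 2 channels). A small standalone sketch, defaulting to CPU for illustration:

import torch

def coord_channels(n, h, w):
    # Normalized coordinate grids in [-1, 1], shaped (n, 2, h, w).
    xs = torch.linspace(-1, 1, w)
    ys = torch.linspace(-1, 1, h)
    y, x = torch.meshgrid(ys, xs, indexing='ij')
    return torch.cat([x.expand(n, 1, -1, -1), y.expand(n, 1, -1, -1)], dim=1)

feat = torch.randn(2, 16, 4, 4)                 # (batch, channels, h, w)
feat = torch.cat([feat, coord_channels(2, 4, 4)], dim=1)
assert feat.shape == (2, 18, 4, 4)              # two coordinate channels appended
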

class MlvlRoIExtractor(BaseRoIExtractor):
    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 embed_dims=1024,
                 stride=1,
                 norm_init=True,
                 fuse_level=3,
                 finest_scale=56,
                 init_cfg=None):
        super(MlvlRoIExtractor, self).__init__(roi_layer, out_channels,
                                               featmap_strides, init_cfg)
        self.embed_dims = embed_dims
        self.finest_scale = finest_scale
        self.fuse_level = fuse_level
        self.norm_init = norm_init

        self.pconvs = nn.ModuleList(
            nn.Conv2d(self.embed_dims, self.embed_dims, 3, stride=1, padding=1)
            for _ in range(self.fuse_level))
        self.pos_embed = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.LayerNorm(1024),
        )
        self.updims = nn.Linear(1024, 4096)

        self.flatten_linear = nn.Linear(
            self.embed_dims * self.roi_layers[0].output_size[0] ** 2, 1024)

        self.norm_init_weights()

    def norm_init_weights(self):
        pass

    def forward(self, feats, rois, roi_scale_factor=None):
        """Forward function."""
        num_imgs = len(rois)
        batch_rois = torch.cat(rois, dim=0).to(feats[0].dtype)
        pos_embed = self.pos_embed(batch_rois)
        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)
        if feats[0].dim() == 3:
            h = w = int(math.sqrt(feats[0].shape[1]))
            assert h == 16
            assert w == 16
            b, c = feats[0].shape[0], feats[0].shape[-1]
            feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2)
                     for item in feats]
        new_rois = []
        for img_id, single_img_roi in enumerate(rois):
            # rescale normalized boxes to the original image scale
            single_img_roi = single_img_roi * 224

            roi_img_id = single_img_roi.new_ones(len(single_img_roi)) * img_id
            single_img_roi = torch.cat(
                [roi_img_id[:, None], single_img_roi], dim=1)
            new_rois.append(single_img_roi)
        rois = torch.cat(new_rois)

        roi_feats = feats[0].new_zeros(self.fuse_level,
                                       rois.size(0), self.out_channels, *out_size)

        for i in range(num_levels):
            if len(rois) > 0:
                rois_ = rois
                ori_dtype = feats[i].dtype
                # run RoIAlign in float32, then cast back to the input dtype
                roi_feats_t = self.roi_layers[i](feats[i].to(
                    torch.float32), rois_.to(torch.float32))

                roi_feats[i] = roi_feats_t.to(ori_dtype)

            else:
                # no RoIs in the batch: add a zero-valued term that touches every
                # parameter so the computation graph stays connected
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.

        fuse_roi_feats = []
        for i in range(self.fuse_level):
            fuse_roi_feats.append(self.pconvs[i](roi_feats[i]))

        fuse_roi_feats = sum(fuse_roi_feats)
        fuse_roi_feats = F.relu(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats.flatten(1, -1)
        fuse_roi_feats = self.flatten_linear(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats + pos_embed
        fuse_roi_feats = self.updims(fuse_roi_feats)
        query_feats = []
        for i in range(num_imgs):
            mask = rois[:, 0] == i
            query_feats.append(fuse_roi_feats[mask])

        return query_feats

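The forward pass above converts per-image, normalized (x1, y1, x2, y2) boxes into the single (n, 5) batched RoI tensor that mmcv's RoIAlign expects, with the image index in column 0. A minimal reproduction of that bookkeeping, separate from the migrated file (box values are illustrative):

import torch

# Two images: the first has two normalized boxes, the second has one.
bboxes = [torch.tensor([[0.1, 0.1, 0.5, 0.5],
                        [0.2, 0.3, 0.9, 0.8]]),
          torch.tensor([[0.0, 0.0, 1.0, 1.0]])]

new_rois = []
for img_id, boxes in enumerate(bboxes):
    boxes = boxes * 224                                  # back to pixel coordinates
    idx = boxes.new_ones(len(boxes)) * img_id            # image index column
    new_rois.append(torch.cat([idx[:, None], boxes], dim=1))
rois = torch.cat(new_rois)                               # (3, 5)

assert rois.shape == (3, 5)
assert rois[2, 0] == 1                                   # third RoI belongs to image 1
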

class MLVLROIQueryModule(nn.Module):
    def __init__(self, embed_dims=1024, out_dims=4096,
                 num_levels=3):
        super(MLVLROIQueryModule, self).__init__()
        self.mlvl_fuse = MLVLFuseModule(input_dims=embed_dims,
                                        embed_dims=embed_dims,
                                        num_levels=num_levels,
                                        num_fuse=5)
        strides = [14 / 8, 14 / 4, 14 / 2, 14]
        assert len(strides) == num_levels
        bbox_roi_extractor = dict(roi_layer=dict(type='RoIAlign',
                                                 output_size=14,
                                                 sampling_ratio=2),
                                  out_channels=embed_dims,
                                  embed_dims=embed_dims,
                                  fuse_level=num_levels,
                                  featmap_strides=strides)

        self.roi_align = MlvlRoIExtractor(**bbox_roi_extractor)

    def forward(self, mlvl_feats, bboxes):
        if mlvl_feats[0].dim() == 3:
            h = w = int(math.sqrt(mlvl_feats[0].shape[1]))
            assert h == 24
            assert w == 24
            b, c = mlvl_feats[0].shape[0], mlvl_feats[0].shape[-1]
            mlvl_feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2) for item in mlvl_feats]
        base_shape = mlvl_feats[0].shape[-2:]
        num_level = len(mlvl_feats)
        to_shape = [(base_shape[0] * 2 ** level, base_shape[1] * 2 ** level)
                    for level in range(num_level)]
        to_shape = to_shape[::-1]
        for level in range(num_level):
            feat = mlvl_feats[level]
            shape = to_shape[level]
            # todo: temporary fix for "upsample_bilinear2d_out_frame" not being
            # implemented for 'BFloat16' -- interpolate in float32, cast back
            feat = feat.to(torch.float32)
            mlvl_feats[level] = F.interpolate(
                feat, size=shape, mode='bilinear', align_corners=True)
            mlvl_feats[level] = mlvl_feats[level].to(torch.bfloat16)

        mlvl_feats = self.mlvl_fuse(mlvl_feats)

        return self.roi_align(mlvl_feats, bboxes)
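The forward pass above manufactures a feature pyramid out of same-resolution ViT features by upsampling each level to a different target size; the reversal puts the finest target first. A sketch of just that shape arithmetic, separate from the migrated file (batch and channel sizes are illustrative):

import torch
import torch.nn.functional as F

base_shape = (24, 24)
num_level = 3
to_shape = [(base_shape[0] * 2 ** lvl, base_shape[1] * 2 ** lvl)
            for lvl in range(num_level)][::-1]          # [(96, 96), (48, 48), (24, 24)]

feats = [torch.randn(1, 8, 24, 24) for _ in range(num_level)]
pyramid = [F.interpolate(f, size=s, mode='bilinear', align_corners=True)
           for f, s in zip(feats, to_shape)]

assert all(p.shape[-2:] == s for p, s in zip(pyramid, to_shape))
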
    	
projects/glamm/utils.py
ADDED

@@ -0,0 +1,280 @@
| 1 | 
            +
            from enum import Enum
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.distributed as dist
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from transformers import PreTrainedModel
         | 
| 8 | 
            +
            from typing import List, Optional
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            IGNORE_INDEX = -100
         | 
| 12 | 
            +
            IMAGE_TOKEN_INDEX = -200
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            DEFAULT_EOS_TOKEN = '</s>'
         | 
| 15 | 
            +
            DEFAULT_BOS_TOKEN = '<s>'
         | 
| 16 | 
            +
            DEFAULT_UNK_TOKEN = '<unk>'
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            DEFAULT_IMAGE_TOKEN = "<image>"
         | 
| 19 | 
            +
            DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
         | 
| 20 | 
            +
            DEFAULT_IM_START_TOKEN = "<im_start>"
         | 
| 21 | 
            +
            DEFAULT_IM_END_TOKEN = "<im_end>"
         | 
| 22 | 
            +
            DEFAULT_BBOX_TOKEN = "<bbox>"
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
             | 
| 26 | 
            +
            # Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99  # noqa: E501
         | 
| 27 | 
            +
            def prepare_inputs_labels_for_multimodal(
         | 
| 28 | 
            +
                    llm: PreTrainedModel,
         | 
| 29 | 
            +
                    input_ids: torch.LongTensor = None,
         | 
| 30 | 
            +
                    position_ids: Optional[torch.LongTensor] = None,
         | 
| 31 | 
            +
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 32 | 
            +
                    past_key_values: Optional[List[torch.FloatTensor]] = None,
         | 
| 33 | 
            +
                    labels: Optional[torch.LongTensor] = None,
         | 
| 34 | 
            +
                    pixel_values: Optional[torch.FloatTensor] = None,
         | 
| 35 | 
            +
                    **kwargs):
         | 
| 36 | 
            +
                if pixel_values is None:
         | 
| 37 | 
            +
                    kwargs.update({
         | 
| 38 | 
            +
                        'input_ids': input_ids,
         | 
| 39 | 
            +
                        'position_ids': position_ids,
         | 
| 40 | 
            +
                        'attention_mask': attention_mask,
         | 
| 41 | 
            +
                        'past_key_values': past_key_values,
         | 
| 42 | 
            +
                        'inputs_embeds': None,
         | 
| 43 | 
            +
                        'labels': labels
         | 
| 44 | 
            +
                    })
         | 
| 45 | 
            +
                    return kwargs
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                _labels = labels
         | 
| 48 | 
            +
                _position_ids = position_ids
         | 
| 49 | 
            +
                _attention_mask = attention_mask
         | 
| 50 | 
            +
                if attention_mask is None:
         | 
| 51 | 
            +
                    attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
         | 
| 52 | 
            +
                else:
         | 
| 53 | 
            +
                    attention_mask = attention_mask.bool()
         | 
| 54 | 
            +
                if position_ids is None:
         | 
| 55 | 
            +
                    position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
         | 
| 56 | 
            +
                if labels is None:
         | 
| 57 | 
            +
                    labels = torch.full_like(input_ids, IGNORE_INDEX)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                # remove the padding using attention_mask -- TODO: double check
         | 
| 60 | 
            +
                input_ids = [
         | 
| 61 | 
            +
                    cur_input_ids[cur_attention_mask]
         | 
| 62 | 
            +
                    for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
         | 
| 63 | 
            +
                ]
         | 
| 64 | 
            +
                labels = [
         | 
| 65 | 
            +
                    cur_labels[cur_attention_mask]
         | 
| 66 | 
            +
                    for cur_labels, cur_attention_mask in zip(labels, attention_mask)
         | 
| 67 | 
            +
                ]
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                new_inputs_embeds = []
         | 
| 70 | 
            +
                new_labels = []
         | 
| 71 | 
            +
                new_input_ids = []
         | 
| 72 | 
            +
                cur_image_idx = 0
         | 
| 73 | 
            +
                for batch_idx, cur_input_ids in enumerate(input_ids):
         | 
| 74 | 
            +
                    num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
         | 
| 75 | 
            +
                    if num_images == 0:
         | 
| 76 | 
            +
                        cur_pixel_values = pixel_values[cur_image_idx]
         | 
| 77 | 
            +
                        cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
         | 
| 78 | 
            +
                        cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
         | 
| 79 | 
            +
                        new_inputs_embeds.append(cur_inputs_embeds)
         | 
| 80 | 
            +
                        new_labels.append(labels[batch_idx])
         | 
| 81 | 
            +
                        new_input_ids.append(cur_input_ids)
         | 
| 82 | 
            +
                        cur_image_idx += 1
         | 
| 83 | 
            +
                        continue
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
         | 
| 86 | 
            +
                    cur_input_ids_noim = []
         | 
| 87 | 
            +
                    cur_labels = labels[batch_idx]
         | 
| 88 | 
            +
                    cur_labels_noim = []
         | 
| 89 | 
            +
                    for i in range(len(image_token_indices) - 1):
         | 
| 90 | 
            +
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])

            # Embed the pure-text chunks, then re-interleave them with the visual features.
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_inputs_embeds = llm.get_input_embeddings()(torch.cat(cur_input_ids_noim))
            cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0)
            cur_new_inputs_embeds = []
            cur_new_labels = []
            cur_new_input_ids = []

            for i in range(num_images + 1):
                cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_new_input_ids.append(cur_input_ids_noim[i])
                if i < num_images:
                    # Visual tokens carry no language-model loss (IGNORE_INDEX), but the
                    # id stream keeps IMAGE_TOKEN_INDEX markers so their positions stay known.
                    cur_pixel_values = pixel_values[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_inputs_embeds.append(cur_pixel_values)
                    cur_new_labels.append(torch.full((cur_pixel_values.shape[0], ), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                    cur_new_input_ids.append(torch.full((cur_pixel_values.shape[0], ), IMAGE_TOKEN_INDEX, device=cur_input_ids.device, dtype=cur_input_ids.dtype))

            cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            cur_new_input_ids = torch.cat(cur_new_input_ids)

            new_inputs_embeds.append(cur_new_inputs_embeds)
            new_labels.append(cur_new_labels)
            new_input_ids.append(cur_new_input_ids)

        # Combine them: right-pad every sample in the batch to the longest sequence.
        max_len = max(x.shape[0] for x in new_inputs_embeds)
        batch_size = len(new_inputs_embeds)

        new_inputs_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        new_input_ids_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_inputs_embeds, new_labels, new_input_ids)):
            cur_len = cur_new_embed.shape[0]
            new_inputs_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                new_input_ids_padded[i, :cur_len] = cur_new_input_ids
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        new_input_ids = new_input_ids_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        kwargs.update({
            'input_ids': None,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': new_inputs_embeds,
            'labels': new_labels,
            'new_input_ids': new_input_ids
        })
        return kwargs
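The block above implements the LLaVA-style multimodal packing step: each IMAGE_TOKEN_INDEX placeholder in a sample is replaced by that image's projected visual features, labels over those positions are filled with IGNORE_INDEX so they contribute no language-model loss, and the resulting ragged sequences are right-padded into a batch. A minimal self-contained sketch of the same splice idea, with toy dimensions and hypothetical helper names (not from this file):

import torch

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200   # illustrative placeholder id, mirroring the constants used above

def splice_one(input_ids, labels, embed, vis_feats_per_image):
    # Replace each IMAGE_TOKEN_INDEX in a single sample by a run of visual features.
    pos = (input_ids == IMAGE_TOKEN_INDEX).nonzero(as_tuple=True)[0].tolist()
    assert len(pos) == len(vis_feats_per_image)
    emb_chunks, lab_chunks, start = [], [], 0
    for p, feats in zip(pos, vis_feats_per_image):
        emb_chunks += [embed(input_ids[start:p]), feats]
        lab_chunks += [labels[start:p],
                       torch.full((feats.shape[0],), IGNORE_INDEX, dtype=labels.dtype)]
        start = p + 1
    emb_chunks.append(embed(input_ids[start:]))
    lab_chunks.append(labels[start:])
    return torch.cat(emb_chunks), torch.cat(lab_chunks)

embed = torch.nn.Embedding(100, 8)                     # stand-in for llm.get_input_embeddings()
ids = torch.tensor([5, 7, IMAGE_TOKEN_INDEX, 9])
labels = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 9])
vis = [torch.randn(4, 8)]                              # 4 visual tokens for the one image
emb, lab = splice_one(ids, labels, embed, vis)
print(emb.shape, lab.shape)                            # torch.Size([7, 8]) torch.Size([7])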
class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(self.sum, np.ndarray):
            total = torch.tensor(
                self.sum.tolist()
                + [
                    self.count,
                ],
                dtype=torch.float32,
                device=device,
            )
        else:
            total = torch.tensor(
                [self.sum, self.count], dtype=torch.float32, device=device
            )

        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        if total.shape[0] > 2:
            self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
        else:
            self.sum, self.count = total.tolist()
        self.avg = self.sum / (self.count + 1e-5)

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ""
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg:.3f}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum:.3f}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.3f}"
        else:
            raise ValueError("invalid summary type %r" % self.summary_type)

        return fmtstr.format(**self.__dict__)

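For reference, these meters are standard eval-loop bookkeeping: update() folds in one batch, and all_reduce() sums the running sum and count across ranks before the averages are read. A minimal single-process sketch:

giou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
for batch_giou in (0.52, 0.61, 0.58):   # pretend per-batch scores
    giou_meter.update(batch_giou, n=1)
print(str(giou_meter))                   # gIoU  0.580 ( 0.570)
# in a torch.distributed run, call giou_meter.all_reduce() before reading .sum / .count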
def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
    area_output = torch.histc(output, bins=K, min=0, max=K - 1)
    area_target = torch.histc(target, bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

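Because intersectionAndUnionGPU returns per-class area histograms rather than a per-image score, IoU stays exact when accumulated across batches: sum the areas first, divide once at the end. A small sketch (inputs are cast to float since torch.histc expects a floating tensor):

K = 2                                   # e.g. background / foreground
inter_sum = torch.zeros(K)
union_sum = torch.zeros(K)
for _ in range(3):                      # stand-ins for evaluation batches
    pred = torch.randint(0, K, (4, 64, 64)).float()
    gt = torch.randint(0, K, (4, 64, 64)).float()
    inter, union, _ = intersectionAndUnionGPU(pred, gt, K)
    inter_sum += inter.cpu()
    union_sum += union.cpu()
iou = inter_sum / (union_sum + 1e-10)
print("per-class IoU:", iou.tolist(), "mIoU:", iou.mean().item())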
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def dict_to_cuda(input_dict):
    for k, v in input_dict.items():
        if isinstance(input_dict[k], torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif isinstance(v, list) and len(v) > 0:
            input_dict[k] = [ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele for ele in v]
    return input_dict
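dict_to_cuda moves tensor values, and tensors one level deep inside list values, onto the current CUDA device; everything else (strings, ints, nested dicts) passes through untouched. Typical use right before a forward pass:

batch = {
    'pixel_values': torch.randn(2, 3, 224, 224),
    'masks': [torch.zeros(224, 224), torch.ones(224, 224)],
    'image_paths': ['a.jpg', 'b.jpg'],      # non-tensors stay on the host
}
batch = dict_to_cuda(batch)                  # needs an available CUDA device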
    	
        projects/llava_sam2/configs/sa2va_4b.py
    ADDED
    
    | @@ -0,0 +1,548 @@ | |
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory

from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss
from peft import LoraConfig

from projects.llava_sam2.models.internvl import InternVL_Slowfast

from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = './pretrained/InternVL2_5-4B'
pretrained_pth = None

# Data
prompt_template = PROMPT_TEMPLATE.phi3_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 2  # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.05

# Save
save_steps = 1000
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True,
    padding_side='right')

extra_image_processor = dict(
    type=DirectResize,
    target_length=1024,
)
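The `official 1024 -> 4e-5` note above ties the learning rate to a global batch of 1024 sequences. With per-device batch 2 and 4 accumulation steps, that global batch is only reached with enough GPUs; a quick sanity check under a linear-scaling assumption (the GPU count here is hypothetical, not part of this config):

batch_size = 2                 # per_device, as above
accumulative_counts = 4
num_gpus = 128                 # hypothetical: 2 * 4 * 128 = 1024
global_batch = batch_size * accumulative_counts * num_gpus
lr = 4e-5 * global_batch / 1024
print(global_batch, lr)        # 1024 4e-05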
#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=VideoLLaVASAMModel_zero3,
    special_tokens=special_tokens,
    frozen_sam2_decoder=False,
    mllm=dict(
        type=InternVL_Slowfast,
        model_path=path,
        freeze_llm=True,
        freeze_visual_encoder=True,
        llm_lora=dict(
            type=LoraConfig,
            r=128,
            lora_alpha=256,
            lora_dropout=0.05,
            bias='none',
            task_type='CAUSAL_LM'),
        special_tokens=special_tokens,
    ),
    tokenizer=tokenizer,
    grounding_encoder=dict(
        type=SAM2TrainRunner,
    ),
    loss_mask=dict(
        type=CrossEntropyLoss,
        use_sigmoid=True,
        reduction='mean',
        loss_weight=2.0),
    loss_dice=dict(
        type=DiceLoss,
        use_sigmoid=True,
        activate=True,
        reduction='mean',
        naive_dice=True,
        eps=1.0,
        loss_weight=0.5),
    pretrained_pth=pretrained_pth,
    loss_sample_points=True,
    # loss_sample_points=False,
    bs=batch_size,
)

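The `llm_lora` entry is a plain peft LoraConfig: rank-128 adapters with lora_alpha=256, so adapter updates are scaled by alpha/r = 2 while the base LLM stays frozen. Outside the xtuner config machinery, the direct peft equivalent would look roughly like this (the checkpoint id is a placeholder, not taken from this config):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained('some-causal-lm')   # placeholder checkpoint
lora_cfg = LoraConfig(r=128, lora_alpha=256, lora_dropout=0.05,
                      bias='none', task_type='CAUSAL_LM')
llm = get_peft_model(llm, lora_cfg)
llm.print_trainable_parameters()   # only the low-rank A/B matrices are trainable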
#######################################################################
#                      PART 3  Dataset & Dataloader                   #
#######################################################################

VIDEO_DATAS = './data/video_datas/'
IMG_DATAS = './data/image_datas/'

############### video res
data_root_revos = './data/video_datas/revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'

data_root_mevis = './data/video_datas/mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'

data_root_refytvos = './data/video_datas/rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'

video_revos_dataset = dict(
    type=VideoReVOSDataset,
    image_folder=video_revos_image_folder,
    expression_file=video_revos_expression_file,
    mask_file=video_revos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=10,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_mevis_dataset = dict(
    type=VideoMeVISDataset,
    image_folder=video_mevis_image_folder,
    expression_file=video_mevis_expression_file,
    mask_file=video_mevis_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_refytvos_dataset = dict(
    type=VideoRefYoutubeVOSDataset,
    image_folder=video_refytvos_image_folder,
    expression_file=video_refytvos_expression_file,
    mask_file=video_refytvos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

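The dataset entries above are mmengine-style lazy configs: `type=` names the class, the remaining keys become constructor kwargs, and the runner instantiates them at startup. Assuming xtuner's standard registry, one entry can be built by hand for inspection:

from xtuner.registry import BUILDER

mevis = BUILDER.build(video_mevis_dataset)   # equivalent to VideoMeVISDataset(image_folder=..., ...)
print(len(mevis))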
################### Video chat
data_root_video_chatunivi = VIDEO_DATAS + 'video_vlm/video_chat/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'

video_qa_dataset = dict(
    type=VideoChatUniViDataset,
    image_folder=video_chatunivi_image_folder,
    json_file=video_chatunivi_json_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

################## image chat
llava_vqa_dataset = dict(
    type=LLaVADataset,
    tokenizer=tokenizer,
    data_path='data/llava_data/LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
    prompt_template=prompt_template,
    special_tokens=special_tokens,
    image_folder='data/llava_data/llava_images/',
)

################## image res
refcoco_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcoco_plus_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco+',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcocog_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcocog',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(umd).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)

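The three referring-expression datasets above share COCO train2014 images and differ only in their split files: refcoco and refcoco+ use the UNC splits, refcocog the UMD split. The `refs(*).p` files are ordinary pickles; a quick way to peek at one (the field names follow the common refer-API layout and are not verified against this repo):

import pickle

with open('data/ref_seg/refcoco/refs(unc).p', 'rb') as f:
    refs = pickle.load(f)
print(len(refs), sorted(refs[0].keys()))   # typically ann_id, image_id, ref_id, sentences, split, ...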
# image gcg datas
glamm_data_root = './data/glamm_data/'

refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'

grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'

flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'

psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'

glamm_refcocog_dataset = dict(
    type=RefCOCOgGCGDataset,
    image_folder=refcocog_image_path,
    data_path=refcocog_ann_file,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_grandf_dataset = dict(
    type=GranDfGCGDataset,
    data_path=grandf_ann_file,
    image_folder=grandf_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=10,
)

glamm_psg_dataset = dict(
    type=OpenPsgGCGDataset,
    data_path=psg_ann_file,
    image_folder=psg_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_flickr_dataset = dict(
    type=FlickrGCGDataset,
    data_path=flickr_ann_file,
    image_folder=flickr_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

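The four GCG (grounded conversation generation) datasets above train the model to produce captions in which groundable phrases are wrapped in the `<p>`/`</p>` special tokens and each is followed by a `[SEG]` token whose hidden state drives the mask decoder. The exact serialization lives in the dataset classes; schematically a target looks like:

# Illustrative GCG-style target, not copied from the dataset code:
answer = '<p> A man </p> [SEG] riding <p> a horse </p> [SEG] on a beach.'
# one [SEG] hidden state per <p>...</p> phrase is decoded into a segmentation mask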
# sam2 data
data_sam2_folder = VIDEO_DATAS + 'segmentation_datasets/sam_v_full/'
data_sam2_expression_file = './whole_pesudo_cap_v3/sam_v_final_v3.json'

video_sam2_dataset = dict(
    type=VideoSAM2Dataset,
    sam2_folder=data_sam2_folder,
    expression_file=data_sam2_expression_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
    select_number=5,
)

# osprey
data_osprey_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
    IMG_DATAS + 'coco/train2014/',
    IMG_DATAS + 'coco/val2014/',
    IMG_DATAS + 'coco/train2017/',
    IMG_DATAS + 'coco/val2017/',
]

image_osprey_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_detail_description_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
    type=OspreyDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_detail_description_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_short_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
    type=OspreyShortDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_short_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_part_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_part_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_positive_neg_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_positive_neg_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

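Many entries above carry a `repeats` field (10 for ReVOS and GranDf, 4 for MeVIS, Ref-YouTube-VOS and the SAM2 pseudo-labelled clips), which oversamples a source within a single epoch. Conceptually it is nothing more than this wrapper (a sketch; the real datasets implement the repetition internally):

class RepeatedDataset:
    """Expose a dataset as `repeats` back-to-back copies of itself."""
    def __init__(self, dataset, repeats=1):
        self.dataset, self.repeats = dataset, repeats

    def __len__(self):
        return len(self.dataset) * self.repeats

    def __getitem__(self, idx):
        return self.dataset[idx % len(self.dataset)]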
train_dataset = dict(
    type=ConcatDataset, datasets=[
        # sem seg
        # semantic_seg_ade20k_dataset,
        # ref seg (listed four times to up-weight referring segmentation)
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        # image qa
        llava_vqa_dataset,
        # video res
        video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
        # video chat
        video_qa_dataset,
        # sam2 pseudo
        video_sam2_dataset,
        # gcg data
        glamm_psg_dataset,
        glamm_grandf_dataset,
        glamm_flickr_dataset,
        glamm_refcocog_dataset,
        # visual prompt
        image_osprey_dataset, image_osprey_description_dataset,
        image_osprey_part_dataset, image_osprey_short_dataset,
        image_osprey_positive_neg_dataset,
    ]
)
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=video_lisa_collate_fn)
)

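Two balancing mechanisms meet in this dataloader: the referring-segmentation triplet is listed four times in the ConcatDataset (a 4x sampling weight on top of any per-dataset `repeats`), and LengthGroupedSampler buckets samples of similar `modality_length` per accumulated step to minimize padding. The rough per-epoch weight of a source is therefore:

# weight of a source within one epoch (rough):
#   (#times listed in `datasets`) * (its `repeats`) * len(source)
ref_seg_listings = 4                                       # each ReferSegmDataset appears four times
revos_weight_factor = 1 * 10                               # listed once, repeats=10
samples_per_opt_step = batch_size * accumulative_counts    # 2 * 4 = 8 per device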
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='bfloat16'
)

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

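With warmup_ratio = 0.05 and max_epochs = 1, LinearLR ramps the learning rate from lr * 1e-5 up to lr over the first 5% of training, then CosineAnnealingLR decays it to zero by the end; `convert_to_iter_based=True` turns those epoch fractions into iteration counts. A worked example of where the boundary lands (the iteration total is hypothetical):

total_iters = 10_000                       # hypothetical length of the single epoch
warmup_end = int(0.05 * 1 * total_iters)   # warmup_ratio * max_epochs -> iteration 500
start_lr = 4e-5 * 1e-5                     # lr * start_factor at iteration 0
print(warmup_end, start_lr)                # 500 ~4e-10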
            #######################################################################
         | 
| 496 | 
            +
            #                           PART 5  Runtime                           #
         | 
| 497 | 
            +
            #######################################################################
         | 
| 498 | 
            +
            # Log the dialogue periodically during the training process, optional
         | 
| 499 | 
            +
            custom_hooks = [
         | 
| 500 | 
            +
                # dict(type=DatasetInfoHook, tokenizer=tokenizer),
         | 
| 501 | 
            +
            ]
         | 
| 502 | 
            +
             | 
| 503 | 
            +
            # configure default hooks
         | 
| 504 | 
            +
            default_hooks = dict(
         | 
| 505 | 
            +
                # record the time of every iteration.
         | 
| 506 | 
            +
                timer=dict(type=IterTimerHook),
         | 
| 507 | 
            +
                # print log every 10 iterations.
         | 
| 508 | 
            +
                logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
         | 
| 509 | 
            +
                # enable the parameter scheduler.
         | 
| 510 | 
            +
                param_scheduler=dict(type=ParamSchedulerHook),
         | 
| 511 | 
            +
                # save checkpoint per `save_steps`.
         | 
| 512 | 
            +
                checkpoint=dict(
         | 
| 513 | 
            +
                    type=CheckpointHook,
         | 
| 514 | 
            +
                    save_optimizer=False,
         | 
| 515 | 
            +
                    by_epoch=False,
         | 
| 516 | 
            +
                    interval=save_steps,
         | 
| 517 | 
            +
                    max_keep_ckpts=save_total_limit),
         | 
| 518 | 
            +
                # set sampler seed in distributed evrionment.
         | 
| 519 | 
            +
                sampler_seed=dict(type=DistSamplerSeedHook),
         | 
| 520 | 
            +
            )
         | 
| 521 | 
            +
             | 
| 522 | 
            +
            # configure environment
         | 
| 523 | 
            +
            env_cfg = dict(
         | 
| 524 | 
            +
                # whether to enable cudnn benchmark
         | 
| 525 | 
            +
                cudnn_benchmark=False,
         | 
| 526 | 
            +
                # set multi process parameters
         | 
| 527 | 
            +
                mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
         | 
| 528 | 
            +
                # set distributed parameters
         | 
| 529 | 
            +
                dist_cfg=dict(backend='nccl'),
         | 
| 530 | 
            +
            )
         | 
| 531 | 
            +
             | 
| 532 | 
            +
            # set visualizer
         | 
| 533 | 
            +
            visualizer = None
         | 
| 534 | 
            +
             | 
| 535 | 
            +
            # set log level
         | 
| 536 | 
            +
            log_level = 'INFO'
         | 
| 537 | 
            +
             | 
| 538 | 
            +
            # load from which checkpoint
         | 
| 539 | 
            +
            load_from = None
         | 
| 540 | 
            +
             | 
| 541 | 
            +
            # whether to resume training from the loaded checkpoint
         | 
| 542 | 
            +
            resume = False
         | 
| 543 | 
            +
             | 
| 544 | 
            +
            # Defaults to use random seed and disable `deterministic`
         | 
| 545 | 
            +
            randomness = dict(seed=None, deterministic=False)
         | 
| 546 | 
            +
             | 
| 547 | 
            +
            # set log processor
         | 
| 548 | 
            +
            log_processor = dict(by_epoch=False)
         | 
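The two schedulers above hand off at `warmup_ratio * max_epochs`: LinearLR scales the base learning rate up from `start_factor` toward 1.0 during warmup, then CosineAnnealingLR decays it to `eta_min=0.0` over the remaining epochs (with `convert_to_iter_based=True`, mmengine evaluates the curve per iteration rather than per epoch). A minimal sketch of the resulting curve, assuming `base_lr`, `warmup_ratio`, and `max_epochs` stand in for the values defined earlier in this config; this only illustrates the shape of the schedule, it is not mmengine's implementation:

    import math

    def effective_lr(progress, base_lr, warmup_ratio, max_epochs,
                     start_factor=1e-5, eta_min=0.0):
        # `progress` is measured in epochs
        warmup_end = warmup_ratio * max_epochs
        if progress < warmup_end:
            # linear warmup: multiplier rises from start_factor to 1.0
            factor = start_factor + (1.0 - start_factor) * progress / warmup_end
            return base_lr * factor
        # cosine decay from base_lr down to eta_min
        t = (progress - warmup_end) / (max_epochs - warmup_end)
        return eta_min + (base_lr - eta_min) * 0.5 * (1.0 + math.cos(math.pi * t))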
projects/llava_sam2/datasets/ChatUniVi_Dataset.py
ADDED

@@ -0,0 +1,389 @@
import logging
import os
from typing import Literal

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import build_origin_dataset
import copy
from .encode_fn import video_lisa_encode_fn
import json
import cv2
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from decord import VideoReader, cpu


def _get_rawvideo_dec(video_path, select_frames=5):

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    elif os.path.exists(video_path.replace('mkv', 'mp4')):
        # some annotations reference .mkv files that were re-encoded to .mp4
        vreader = VideoReader(video_path.replace('mkv', 'mp4'), ctx=cpu(0))
    else:
        raise FileNotFoundError(video_path)

    fps = vreader.get_avg_fps()
    f_start = 0
    f_end = len(vreader) - 1
    num_frames = f_end - f_start + 1
    assert num_frames > 0, f'num_frames: {num_frames}, f_start: {f_start}, f_end: {f_end}, fps: {fps}, video_path: {video_path}'
    # T x 3 x H x W
    if num_frames <= select_frames:
        sample_pos = range(f_start, f_end + 1)
    else:
        split_point = np.linspace(0, num_frames, num=select_frames + 1, dtype=int)
        sample_pos = [np.random.randint(split_point[i], split_point[i + 1]) for i in range(select_frames)]
    patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
    return patch_images

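# `_get_rawvideo_dec` above samples one random frame from each of
# `select_frames` equal segments of the clip, so the sample stays temporally
# spread out instead of clustering. For example (illustrative values), a
# 100-frame video with select_frames=5 draws one index from each of
# [0, 20), [20, 40), [40, 60), [60, 80), [80, 100):
#
#     split_point = np.linspace(0, 100, num=6, dtype=int)   # [0 20 40 60 80 100]
#     sample_pos = [np.random.randint(split_point[i], split_point[i + 1])
#                   for i in range(5)]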

class VideoChatUniViDataset(Dataset):
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
    FAST_IMG_START_TOKEN = '<fast_img>'
    FAST_IMG_END_TOKEN = '</fast_img>'

    def __init__(self,
                 image_folder,
                 json_file,
                 extra_image_processor=None,
                 tokenizer=None,
                 sampled_frames=10,
                 offline_processed_text_folder=None,
                 template_map_fn=None,
                 max_length=2048,
                 lazy=True,
                 repeats=1,
                 special_tokens=None,
                 use_fast=False,
                 n_fast_images=50,
                 fast_pool_size=4,
                 arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
                 preprocessor=None,
    ):
        assert lazy is True
        self.tokenizer = BUILDER.build(tokenizer)
        self.sampled_frames = sampled_frames
        assert offline_processed_text_folder or (json_file and tokenizer)
        self.lazy = lazy

        self.max_length = max_length

        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if offline_processed_text_folder and json_file:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from '
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            json_datas = self.json_file_preprocess(json_file)
            self.json_datas = json_datas
            json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
            if self.lazy:
                self.text_data = build_origin_dataset(json_data, 'train')
            else:
                raise NotImplementedError

        self.image_folder = image_folder
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.arch_type = arch_type
        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''
        self.repeats = repeats

        self._system = ''

        self.downsample_ratio = 0.5
        if self.arch_type == 'llava':
            self.downsample_ratio = 1
        self.image_size = 448
        if self.arch_type == 'llava':
            self.image_size = 336
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
        if self.arch_type == 'qwen':
            self.patch_token = 1

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)

        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.use_fast = use_fast
        self.n_fast_images = n_fast_images
        self.fast_pool_size = fast_pool_size

        # for visualization debug
        self.save_folder = './work_dirs/video_debug/'
        self.cur_number = 0

        print('Video Chat dataset includes {} items.'.format(len(self.text_data)))

    def __len__(self):
        return len(self.text_data) * self.repeats

    @property
    def modality_length(self):
        # lazy mode: report a constant placeholder length per sample
        length_list = []
        for data_dict in self.text_data:
            cur_len = 10000
            length_list.append(cur_len)
        return length_list

    def real_len(self):
        return len(self.text_data)

    def json_file_preprocess(self, json_file):
        # prepare expression annotation files
        with open(json_file, 'r') as f:
            json_datas = json.load(f)
        return json_datas

    def dataset_map_fn(self, data_dict, select_k=5):
        assert 'video' in data_dict
        # video
        video_file = data_dict['video']
        video_file = os.path.join(self.image_folder, video_file)
        images = _get_rawvideo_dec(video_file, select_frames=select_k)
        if self.use_fast:
            fast_images = _get_rawvideo_dec(video_file, select_frames=self.n_fast_images)
        else:
            fast_images = None

        conversation = data_dict['conversations']

        # prepare text
        if self.use_fast:
            text_dict = self.prepare_text(
                select_k, conversation, num_image_tokens=self.patch_token,
                n_fast_images=len(fast_images),
            )
        else:
            text_dict = self.prepare_text(
                select_k, conversation, num_image_tokens=self.patch_token,
            )

        ret = {'images': images, 'conversation': text_dict['conversation'], 'fast_images': fast_images}
        return ret

    def prepare_text(self, n_frames, conversation, num_image_tokens=256, n_fast_images=0):

        if self.use_fast:
            fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
                          f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
                          f'{self.FAST_IMG_END_TOKEN}' + '\n'
        else:
            fast_frame_token_str = ''

        frame_token_str = f'{self.IMG_START_TOKEN}' \
                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                          f'{self.IMG_END_TOKEN}'

        questions = []
        answers = []

        for conv in conversation:
            if conv['from'] == 'human':
                questions.append(conv['value'].replace('<image>', ''))
            else:
                answers.append(conv['value'])
        assert len(questions) == len(answers)

        qa_list = []
        for i, (question, answer) in enumerate(zip(questions, answers)):
            if i == 0:
                # prepend one placeholder block per sampled frame to the first turn
                frame_tokens = frame_token_str + '\n'
                frame_tokens = frame_tokens * n_frames
                frame_tokens = frame_tokens.strip()
                frame_tokens = fast_frame_token_str + frame_tokens
                qa_list.append(
                    {'from': 'human', 'value': frame_tokens + question}
                )
            else:
                qa_list.append(
                    {'from': 'human', 'value': question}
                )
            qa_list.append(
                {'from': 'gpt', 'value': answer}
            )

        input = ''
        conversation = []
        for msg in qa_list:
            if msg['from'] == 'human':
                input += msg['value']
            elif msg['from'] == 'gpt':
                conversation.append({'input': input, 'output': msg['value']})
                input = ''
            else:
                raise NotImplementedError

        # add system information
        conversation[0].update({'system': self._system})
        return {'conversation': conversation}

    def __getitem__(self, index):
        index = index % self.real_len()
        selected_data_dict = copy.deepcopy(self.text_data[index])
        data_dict = self.dataset_map_fn(selected_data_dict, select_k=self.sampled_frames)

        assert 'images' in data_dict.keys()
        if self.use_fast:
            assert 'fast_images' in data_dict.keys()
        pixel_values = []
        num_video_tokens = None
        num_frame_tokens = None
        if data_dict.get('images', None) is not None:
            frames_files = data_dict['images']
            for frame_image in frames_files:
                frame_image = frame_image.convert('RGB')
                ori_width, ori_height = frame_image.size

                if self.preprocessor is not None:
                    # keep raw PIL frames; the preprocessor consumes the whole list below
                    pass
                else:
                    frame_image = self.transformer(frame_image)
                pixel_values.append(frame_image)

            if self.preprocessor is not None:
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
                    num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                    num_frames = _data_dict['image_grid_thw'].shape[0]
                    num_video_tokens = num_frame_tokens * num_frames
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(pixel_values, do_resize=True,
                                                   size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                else:
                    raise NotImplementedError
                data_dict.update(_data_dict)
            else:
                pixel_values = torch.stack(pixel_values, dim=0)  # (n_f, 3, h, w)
                data_dict['pixel_values'] = pixel_values
        else:
            data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
            data_dict['masks'] = None

        if num_video_tokens is not None:
            assert self.patch_token == 1
            input_str = data_dict['conversation'][0]['input']
            input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
            assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
            data_dict['conversation'][0]['input'] = input_str

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        # for fast branch
        if self.use_fast:
            fast_pixel_values = []
            frames_files = data_dict['fast_images']
            for frame_image in frames_files:
                frame_image = frame_image.convert('RGB')
                ori_width, ori_height = frame_image.size

                frame_image = self.transformer(frame_image)
                fast_pixel_values.append(frame_image)

            fast_pixel_values = torch.stack(fast_pixel_values, dim=0)  # (n_f, 3, h, w)
            data_dict['fast_pixel_values'] = fast_pixel_values

        # # for debug
        # self.visualization_debug(data_dict)
        # if self.cur_number < 10:
        #     return self[random.randint(0, len(self))]

        data_dict['type'] = 'video'
        return data_dict

    def visualization_debug(self, data_dict):
        save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder, exist_ok=True)
        self.cur_number += 1

        # images
        show_images = []

        pixel_values = data_dict['pixel_values']
        save_folder_image = os.path.join(save_folder, 'image')
        if not os.path.exists(save_folder_image):
            os.mkdir(save_folder_image)
        for i_image, image_pixel_value in enumerate(pixel_values):
            # undo the normalization applied in `self.transformer`
            for c in range(3):
                image_pixel_value[c] = image_pixel_value[c] * self.IMAGENET_STD[c] + self.IMAGENET_MEAN[c]
            image_pixel_value = image_pixel_value * 255
            image_pixel_value = image_pixel_value.permute(1, 2, 0)
            image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
            show_images.append(image_pixel_value)
            # cv2 expects BGR channel order when writing JPEGs
            cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)),
                        cv2.cvtColor(image_pixel_value, cv2.COLOR_RGB2BGR))

        # text
        input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
        with open(os.path.join(save_folder, 'text.json'), 'w') as f:
            json.dump([input_text], f)

        return
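In `prepare_text` above, each sampled frame is represented by a fixed-width block of placeholder tokens that the model later replaces with vision features. With the InternVL-style defaults in this file (448-px frames, 14-px patches, 0.5 downsample ratio, hence 256 context tokens per frame), the first human turn is prefixed roughly as in this sketch (the frame count and question are toy values, not taken from the code):

    n_frames = 2
    patch_token = int((448 // 14) ** 2 * 0.5 ** 2)            # 256 tokens per frame
    frame = '<img>' + '<IMG_CONTEXT>' * patch_token + '</img>'
    prefix = ((frame + '\n') * n_frames).strip()
    prompt = prefix + 'What is happening in this video?'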
projects/llava_sam2/datasets/GCG_Dataset.py
ADDED

@@ -0,0 +1,375 @@
import json
import os

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from PIL import Image
from torch.utils.data import Dataset
from pycocotools import mask
import numpy as np
import copy

from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import torchvision.transforms as T
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from torchvision.transforms.functional import InterpolationMode
from .encode_fn import video_lisa_encode_fn
from .utils import dynamic_preprocess

from .gcg_process import glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn, glamm_refcocog_map_fn


class GCGDataset(Dataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 image_folder,
                 data_path=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
    ):
        super().__init__()
        assert lazy
        self.lazy = lazy
        self.max_length = max_length

        json_data = self.json_file_preprocess(data_path)
        json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
        self.text_data = build_origin_dataset(json_data, 'train')

        self.image_folder = image_folder

        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.repeats = repeats

        self._system = ''

        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        self.image_size = 448
        self.use_thumbnail = True
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))

        self.transformer = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

        self.single_image_mode = single_image_mode

    def json_file_preprocess(self, data_path):
        with open(data_path, 'r') as f:
            json_data = json.load(f)
        return json_data

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.text_data:
            if self.lazy:
                cur_len = 100
            else:
                cur_len = len(data_dict['input_ids'])
                if data_dict.get('image', None) is None:
                    cur_len = -cur_len
            length_list.append(cur_len)
        return length_list * self.repeats

    def __len__(self):
        return len(self.text_data) * self.repeats

    def real_len(self):
        return len(self.text_data)

    def decode_mask(self, object_masks, ori_height, ori_width):
        binary_masks = []
        for object_mask in object_masks:
            binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
            for seg in object_mask:
                rles = mask.frPyObjects([seg], ori_height, ori_width)
                m = mask.decode(rles)
                m = m.astype(np.uint8)
                binary_mask += m.squeeze()

            binary_masks.append(binary_mask)
        if len(binary_masks) == 0:
            return None
        masks = np.stack(binary_masks, axis=0)
        masks = torch.from_numpy(masks)
        return masks

    def dataset_map_fn(self, data_dict):
        data_dict = glamm_refcocog_map_fn(data_dict)
        return data_dict

    def replace_image_str(self, data_dict, image_str):
        data_dict['conversation'][0]['input'] = \
            data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
        return data_dict

    def __getitem__(self, index):

        index = index % self.real_len()
        data_dict = copy.deepcopy(self.text_data[index])

        # parse datasets
        result = self.dataset_map_fn(data_dict)
        data_dict.update(result)

        # process image
        image_file = data_dict['image']
        image = Image.open(os.path.join(self.image_folder,
                                        image_file)).convert('RGB')
        ori_width, ori_height = image.size
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        if self.single_image_mode:
            images = [image]
        else:
            images = dynamic_preprocess(image, self.min_dynamic_patch,
                                        self.max_dynamic_patch,
                                        self.image_size, self.use_thumbnail)
        pixel_values = [self.transformer(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        data_dict['pixel_values'] = pixel_values

        num_image_tokens = pixel_values.shape[0] * self.patch_token
        image_token_str = f'{self.IMG_START_TOKEN}' \
                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                          f'{self.IMG_END_TOKEN}'

        data_dict = self.replace_image_str(data_dict, image_token_str)

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
                                      with_image_token=True)
        data_dict.update(result)
        # process mask
        data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)

        if data_dict['masks'] is None:
            # no valid mask for this sample; fall back to the first sample
            return self.__getitem__(0)

        return data_dict

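# Note on `decode_mask` above: each object annotation is a list of COCO-style
# polygons. pycocotools converts a polygon to run-length encoding and back to
# a bitmap; `mask.decode` returns an (h, w, 1) array that is squeezed and
# accumulated so every object ends up as a single (h, w) uint8 mask. A minimal
# sketch with made-up values (not part of the dataset code):
#
#     from pycocotools import mask
#     poly = [10, 10, 60, 10, 60, 60, 10, 60]      # x0, y0, x1, y1, ... square
#     rles = mask.frPyObjects([poly], 100, 100)    # polygon -> RLE on a 100x100 grid
#     m = mask.decode(rles).squeeze()              # (100, 100) uint8 in {0, 1}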

class RefCOCOgGCGDataset(GCGDataset):
    def __init__(self,
                 image_folder,
                 data_path=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
                 ):
        super().__init__(
            image_folder=image_folder,
            data_path=data_path,
            tokenizer=tokenizer,
            max_length=max_length,
            special_tokens=special_tokens,
            template_map_fn=template_map_fn,
            extra_image_processor=extra_image_processor,
            lazy=lazy,
            repeats=repeats,
            single_image_mode=single_image_mode,
        )

    def json_file_preprocess(self, data_path):
        with open(data_path, 'r') as f:
            json_data = json.load(f)

        # convert each {id: annotation} entry to dict(..., id=id)
        for idx in range(len(json_data)):
            id = list(json_data[idx].keys())[0]
            json_data[idx] = json_data[idx][id]
            json_data[idx].update({'id': id})
        return json_data


class GranDfGCGDataset(GCGDataset):
    def __init__(self,
                 image_folder,
                 data_path=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
                 ):
        super().__init__(
            image_folder=image_folder,
            data_path=data_path,
            tokenizer=tokenizer,
         | 
| 244 | 
            +
                        max_length=max_length,
         | 
| 245 | 
            +
                        special_tokens=special_tokens,
         | 
| 246 | 
            +
                        template_map_fn=template_map_fn,
         | 
| 247 | 
            +
                        extra_image_processor=extra_image_processor,
         | 
| 248 | 
            +
                        lazy=lazy,
         | 
| 249 | 
            +
                        repeats=repeats,
         | 
| 250 | 
            +
                        single_image_mode=single_image_mode,
         | 
| 251 | 
            +
                    )
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                def dataset_map_fn(self, data_dict):
         | 
| 254 | 
            +
                    data_dict = glamm_granf_map_fn(data_dict)
         | 
| 255 | 
            +
                    return data_dict
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                def decode_mask(self, object_masks, ori_height, ori_width):
         | 
| 258 | 
            +
                    binary_masks = []
         | 
| 259 | 
            +
                    for object_mask in object_masks:
         | 
| 260 | 
            +
                        binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                        for rle in object_mask:
         | 
| 263 | 
            +
                            m = mask.decode(rle).astype(np.uint8)
         | 
| 264 | 
            +
                            binary_mask += m.squeeze()
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                        binary_masks.append(binary_mask)
         | 
| 267 | 
            +
                    if len(binary_masks) == 0:
         | 
| 268 | 
            +
                        return None
         | 
| 269 | 
            +
                    masks = np.stack(binary_masks, axis=0)
         | 
| 270 | 
            +
                    masks = torch.from_numpy(masks)
         | 
| 271 | 
            +
                    return masks
         | 
| 272 | 
            +
             | 
| 273 | 
            +
            class OpenPsgGCGDataset(GranDfGCGDataset):
         | 
| 274 | 
            +
                def __init__(self,
         | 
| 275 | 
            +
                             image_folder,
         | 
| 276 | 
            +
                             data_path=None,
         | 
| 277 | 
            +
                             tokenizer=None,
         | 
| 278 | 
            +
                             max_length=8196,
         | 
| 279 | 
            +
                             special_tokens=None,
         | 
| 280 | 
            +
                             template_map_fn=None,
         | 
| 281 | 
            +
                             extra_image_processor=None,
         | 
| 282 | 
            +
                             lazy=True,
         | 
| 283 | 
            +
                             repeats=1,
         | 
| 284 | 
            +
                             single_image_mode=False,
         | 
| 285 | 
            +
                             ):
         | 
| 286 | 
            +
                    super().__init__(
         | 
| 287 | 
            +
                        image_folder=image_folder,
         | 
| 288 | 
            +
                        data_path=data_path,
         | 
| 289 | 
            +
                        tokenizer=tokenizer,
         | 
| 290 | 
            +
                        max_length=max_length,
         | 
| 291 | 
            +
                        special_tokens=special_tokens,
         | 
| 292 | 
            +
                        template_map_fn=template_map_fn,
         | 
| 293 | 
            +
                        extra_image_processor=extra_image_processor,
         | 
| 294 | 
            +
                        lazy=lazy,
         | 
| 295 | 
            +
                        repeats=repeats,
         | 
| 296 | 
            +
                        single_image_mode=single_image_mode,
         | 
| 297 | 
            +
                    )
         | 
| 298 | 
            +
                def dataset_map_fn(self, data_dict):
         | 
| 299 | 
            +
                    data_dict = glamm_openpsg_map_fn(data_dict)
         | 
| 300 | 
            +
                    return data_dict
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
            class FlickrGCGDataset(GCGDataset):
         | 
| 304 | 
            +
                def __init__(self,
         | 
| 305 | 
            +
                             image_folder,
         | 
| 306 | 
            +
                             data_path=None,
         | 
| 307 | 
            +
                             tokenizer=None,
         | 
| 308 | 
            +
                             max_length=8196,
         | 
| 309 | 
            +
                             special_tokens=None,
         | 
| 310 | 
            +
                             template_map_fn=None,
         | 
| 311 | 
            +
                             extra_image_processor=None,
         | 
| 312 | 
            +
                             lazy=True,
         | 
| 313 | 
            +
                             repeats=1,
         | 
| 314 | 
            +
                             single_image_mode=False,
         | 
| 315 | 
            +
                             ):
         | 
| 316 | 
            +
                    super().__init__(
         | 
| 317 | 
            +
                        image_folder=image_folder,
         | 
| 318 | 
            +
                        data_path=data_path,
         | 
| 319 | 
            +
                        tokenizer=tokenizer,
         | 
| 320 | 
            +
                        max_length=max_length,
         | 
| 321 | 
            +
                        special_tokens=special_tokens,
         | 
| 322 | 
            +
                        template_map_fn=template_map_fn,
         | 
| 323 | 
            +
                        extra_image_processor=extra_image_processor,
         | 
| 324 | 
            +
                        lazy=lazy,
         | 
| 325 | 
            +
                        repeats=repeats,
         | 
| 326 | 
            +
                        single_image_mode=single_image_mode,
         | 
| 327 | 
            +
                    )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                def dataset_map_fn(self, data_dict):
         | 
| 330 | 
            +
                    data_dict = glamm_flickr_map_fn(data_dict)
         | 
| 331 | 
            +
                    return data_dict
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                def json_file_preprocess(self, data_path):
         | 
| 334 | 
            +
                    def filter_images(data_infos, min_size):
         | 
| 335 | 
            +
                        return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    # convert {id: dict} to dict(..., id=xx)
         | 
| 338 | 
            +
                    from pycocotools.coco import COCO
         | 
| 339 | 
            +
                    self.coco = COCO(data_path)
         | 
| 340 | 
            +
                    self.image_ids = self.coco.getImgIds()
         | 
| 341 | 
            +
                    data_infos = []
         | 
| 342 | 
            +
                    total_ann_ids = []
         | 
| 343 | 
            +
                    removed_img_count = 0
         | 
| 344 | 
            +
                    for img_id in self.image_ids:
         | 
| 345 | 
            +
                        info = self.coco.loadImgs([img_id])[0]
         | 
| 346 | 
            +
                        if len(info['caption'].split(' ')) < 3:
         | 
| 347 | 
            +
                            removed_img_count += 1
         | 
| 348 | 
            +
                            continue
         | 
| 349 | 
            +
                        info['filename'] = info['file_name'].split('_')[-1]
         | 
| 350 | 
            +
                        info['height'] = int(info['height'])
         | 
| 351 | 
            +
                        info['width'] = int(info['width'])
         | 
| 352 | 
            +
                        data_infos.append(info)
         | 
| 353 | 
            +
                        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
         | 
| 354 | 
            +
                        total_ann_ids.extend(ann_ids)
         | 
| 355 | 
            +
                    assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
         | 
| 356 | 
            +
                    print(f'Removed {removed_img_count} images.')
         | 
| 357 | 
            +
                    data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    # obtain_annotations
         | 
| 360 | 
            +
                    for data_info in data_infos:
         | 
| 361 | 
            +
                        ann_ids = self.coco.getAnnIds(imgIds=data_info['id'])
         | 
| 362 | 
            +
                        ann_info = self.coco.loadAnns(ann_ids)
         | 
| 363 | 
            +
                        data_info.update({'ann_info': ann_info})
         | 
| 364 | 
            +
                    return data_infos
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                def decode_mask(self, object_masks, ori_height, ori_width):
         | 
| 367 | 
            +
                    binary_masks = []
         | 
| 368 | 
            +
                    for object_mask in object_masks:
         | 
| 369 | 
            +
                        binary_mask = mask.decode(object_mask).astype(np.uint8)
         | 
| 370 | 
            +
                        binary_masks.append(binary_mask)
         | 
| 371 | 
            +
                    if len(binary_masks) == 0:
         | 
| 372 | 
            +
                        return None
         | 
| 373 | 
            +
                    masks = np.stack(binary_masks, axis=0)
         | 
| 374 | 
            +
                    masks = torch.from_numpy(masks)
         | 
| 375 | 
            +
                    return masks
         | 
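For reference, the number of <IMG_CONTEXT> placeholders spliced into each conversation follows directly from the defaults these datasets share (image_size=448, patch_size=14, downsample_ratio=0.5): every 448x448 crop contributes (448 // 14)**2 * 0.5**2 = 256 context tokens. A minimal standalone sketch of the token-string construction in __getitem__, where num_crops = 2 is a hypothetical value (e.g. one dynamic patch plus the thumbnail), not something fixed by the code above:

    IMG_START_TOKEN, IMG_CONTEXT_TOKEN, IMG_END_TOKEN = '<img>', '<IMG_CONTEXT>', '</img>'

    image_size, patch_size, downsample_ratio = 448, 14, 0.5
    patch_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)  # 256

    num_crops = 2  # assumed number of crops returned by dynamic_preprocess
    num_image_tokens = num_crops * patch_token  # 512
    image_token_str = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_tokens}{IMG_END_TOKEN}'
    assert image_token_str.count(IMG_CONTEXT_TOKEN) == 512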
    	
        projects/llava_sam2/datasets/Grand_Dataset.py
    ADDED
    
@@ -0,0 +1,241 @@
import json
import os
import random

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict, load_from_disk
from PIL import Image
from torch.utils.data import Dataset
from pycocotools import mask
import numpy as np
import copy

from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
import torchvision.transforms as T
from xtuner.utils import DEFAULT_IMAGE_TOKEN
from torchvision.transforms.functional import InterpolationMode
from .encode_fn import video_lisa_encode_fn
from .utils import dynamic_preprocess

from .grand_process import glamm_grand_map_fn


class GranDDataset(Dataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 image_folder,
                 json_folder=None,
                 tokenizer=None,
                 max_length=8196,
                 special_tokens=None,
                 template_map_fn=None,
                 extra_image_processor=None,
                 lazy=True,
                 repeats=1,
                 single_image_mode=False,
                 image_list_save_path='./work_dirs/grand_image.json',
                 json_list_save_path='./work_dirs/grand_jsons.json',
                 ):
        super().__init__()
        assert lazy
        self.lazy = lazy
        self.max_length = max_length

        self.image_list_save_path = image_list_save_path
        self.json_list_save_path = json_list_save_path

        json_files, image_path_dict = self.json_file_preprocess(image_folder, json_folder)
        self.json_data = json_files
        self.image_path_dict = image_path_dict

        self.image_folder = image_folder

        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.repeats = repeats

        self._system = ''

        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        self.image_size = 448
        self.use_thumbnail = True
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))

        self.transformer = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

        self.single_image_mode = single_image_mode

    def json_file_preprocess(self, image_folder, json_folder):
        # list annotation jsons, cached to json_list_save_path
        print("Processing GRAND json files ...")
        if os.path.exists(self.json_list_save_path):
            with open(self.json_list_save_path, 'r') as f:
                json_files = json.load(f)
        else:
            json_files = os.listdir(json_folder)
            _json_files = []
            for _file in json_files:
                if '.json' in _file:
                    _json_files.append(os.path.join(json_folder, _file))
            json_files = _json_files
            with open(self.json_list_save_path, 'w') as f:
                json.dump(json_files, f)
        print(f"Finished, {len(json_files)} json files.")

        # list images from the SA-1B style shards, cached to image_list_save_path
        print("Processing GRAND image files ...")
        if os.path.exists(self.image_list_save_path):
            with open(self.image_list_save_path, 'r') as f:
                image_path_dict = json.load(f)
        else:
            sub_folders = os.listdir(image_folder)
            _sub_folders = []
            for folder_name in sub_folders:
                if 'sa_00' in folder_name:
                    _sub_folders.append(folder_name)
            sub_folders = _sub_folders
            sub_folders = [os.path.join(image_folder, folder_name) for folder_name in sub_folders]

            image_path_dict = {}
            for sub_folder in sub_folders:
                files = os.listdir(sub_folder)
                for _file in files:
                    if '.jpg' in _file:
                        image_path_dict[_file] = os.path.join(sub_folder, _file)

            with open(self.image_list_save_path, 'w') as f:
                json.dump(image_path_dict, f)
        print(f"Finished, {len(image_path_dict)} image files.")

        return json_files, image_path_dict

    @property
    def modality_length(self):
        length_list = [10000] * len(self.json_data)
        return length_list * self.repeats

    def __len__(self):
        return len(self.json_data) * self.repeats

    def real_len(self):
        return len(self.json_data)

    def decode_mask(self, object_masks, ori_height, ori_width):
        binary_masks = []
        for object_mask in object_masks:
            binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
            for seg in object_mask:
                m = mask.decode(seg)
                m = m.astype(np.uint8)
                binary_mask += m.squeeze()

            binary_masks.append(binary_mask)
        if len(binary_masks) == 0:
            return None
        masks = np.stack(binary_masks, axis=0)
        masks = torch.from_numpy(masks)
        return masks

    def dataset_map_fn(self, data_dict):
        data_dict = glamm_grand_map_fn(data_dict)
        return data_dict

    def replace_image_str(self, data_dict, image_str):
        data_dict['conversation'][0]['input'] = \
            data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
        return data_dict

    def __getitem__(self, index):
        index = index % self.real_len()
        json_file_path = self.json_data[index]
        with open(json_file_path, 'r') as f:
            json_dict = json.load(f)

        image_name = list(json_dict.keys())[0]

        # resample if the annotated image is missing from the downloaded shards
        if image_name not in self.image_path_dict.keys():
            return self.__getitem__(random.randint(0, len(self.json_data) - 1))
        image_path = self.image_path_dict[image_name]

        json_dict = json_dict[image_name]
        # parse datasets
        result = self.dataset_map_fn(json_dict)
        json_dict.update(result)
        data_dict = json_dict

        data_dict['image'] = image_path

        # process image
        image_file = data_dict['image']
        try:
            image = Image.open(os.path.join(self.image_folder,
                                            image_file)).convert('RGB')
        except Exception:
            return self.__getitem__(random.randint(0, len(self.json_data) - 1))
        ori_width, ori_height = image.size
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        if self.single_image_mode:
            images = [image]
        else:
            images = dynamic_preprocess(image, self.min_dynamic_patch,
                                        self.max_dynamic_patch,
                                        self.image_size, self.use_thumbnail)
        pixel_values = [self.transformer(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        data_dict['pixel_values'] = pixel_values

        num_image_tokens = pixel_values.shape[0] * self.patch_token
        image_token_str = f'{self.IMG_START_TOKEN}' \
                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                          f'{self.IMG_END_TOKEN}'

        data_dict = self.replace_image_str(data_dict, image_token_str)

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
                                      with_image_token=True)
        data_dict.update(result)
        # process mask
        data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)

        if data_dict['masks'] is None:
            return self.__getitem__(random.randint(0, len(self.json_data) - 1))

        return data_dict
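The decode_mask methods above union several RLE-encoded parts into one binary mask per object. A self-contained sketch of the same pycocotools pattern, using a fabricated toy mask purely for illustration (the part_a/part_b arrays and their shapes are made up, not from the dataset):

    import numpy as np
    import torch
    from pycocotools import mask as mask_utils

    # fabricate two RLE parts the way pycocotools expects them (Fortran-ordered uint8)
    part_a = np.zeros((4, 4), dtype=np.uint8)
    part_a[1:3, 1:3] = 1
    part_b = np.zeros((4, 4), dtype=np.uint8)
    part_b[0, 0] = 1
    object_mask = [mask_utils.encode(np.asfortranarray(p)) for p in (part_a, part_b)]

    # union the parts into a single binary mask, as decode_mask does per object
    binary_mask = np.zeros((4, 4), dtype=np.uint8)
    for seg in object_mask:
        binary_mask += mask_utils.decode(seg).astype(np.uint8)

    masks = torch.from_numpy(np.stack([binary_mask], axis=0))  # shape (n_obj, H, W)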
    	
        projects/llava_sam2/datasets/MeVIS_Dataset.py
    ADDED
    
@@ -0,0 +1,5 @@
from .ReVOS_Dataset import VideoReVOSDataset


class VideoMeVISDataset(VideoReVOSDataset):
    pass
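VideoMeVISDataset is a pure alias: the MeVIS annotations already follow the ReVOS format, so the subclass only gives the pipeline its own name in configs. The datasets in this directory also share a repeats idiom: __len__ reports real_len() * repeats so one pass over the loader covers the data several times, while __getitem__ wraps the index back with index % self.real_len(). A minimal sketch of that pattern in isolation (RepeatedToyDataset is a made-up name for illustration):

    from torch.utils.data import Dataset

    class RepeatedToyDataset(Dataset):
        def __init__(self, items, repeats=1):
            self.items = items
            self.repeats = repeats

        def real_len(self):
            return len(self.items)

        def __len__(self):
            # the sampler sees the data `repeats` times per epoch
            return self.real_len() * self.repeats

        def __getitem__(self, index):
            # wrap back into the real range, as the datasets here do
            return self.items[index % self.real_len()]

    assert len(RepeatedToyDataset('abc', repeats=4)) == 12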
    	
        projects/llava_sam2/datasets/Osprey_Dataset.py
    ADDED
    
@@ -0,0 +1,463 @@
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            from datasets import Dataset as HFDataset
         | 
| 6 | 
            +
            from datasets import DatasetDict, load_from_disk
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            from torch.utils.data import Dataset
         | 
| 9 | 
            +
            from pycocotools import mask as maskUtils
         | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import copy
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from xtuner.registry import BUILDER
         | 
| 14 | 
            +
            from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
         | 
| 15 | 
            +
            import torchvision.transforms as T
         | 
| 16 | 
            +
            from xtuner.utils import DEFAULT_IMAGE_TOKEN
         | 
| 17 | 
            +
            from torchvision.transforms.functional import InterpolationMode
         | 
| 18 | 
            +
            from .encode_fn import video_lisa_encode_fn
         | 
| 19 | 
            +
            from .utils import dynamic_preprocess
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import random
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            import torch.nn.functional as F
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            class OspreyDataset(Dataset):
         | 
| 26 | 
            +
                os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         | 
| 27 | 
            +
                IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
         | 
| 28 | 
            +
                IMG_START_TOKEN = '<img>'
         | 
| 29 | 
            +
                IMG_END_TOKEN = '</img>'
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                LIMIT = ''
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                VP_START_TOKEN = '<vp>'
         | 
| 34 | 
            +
                VP_END_TOKEN = '</vp>'
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                IMAGENET_MEAN = (0.485, 0.456, 0.406)
         | 
| 37 | 
            +
                IMAGENET_STD = (0.229, 0.224, 0.225)
         | 
| 38 | 
            +
                def __init__(self,
         | 
| 39 | 
            +
                             image_folder,
         | 
| 40 | 
            +
                             data_path=None,
         | 
| 41 | 
            +
                             tokenizer=None,
         | 
| 42 | 
            +
                             max_length=8196,
         | 
| 43 | 
            +
                             special_tokens=None,
         | 
| 44 | 
            +
                             template_map_fn=None,
         | 
| 45 | 
            +
                             extra_image_processor=None,
         | 
| 46 | 
            +
                             lazy=True,
         | 
| 47 | 
            +
                             repeats=1,
         | 
| 48 | 
            +
                             single_image_mode=False,
         | 
| 49 | 
            +
                ):
         | 
| 50 | 
            +
                    super().__init__()
         | 
| 51 | 
            +
                    assert lazy
         | 
| 52 | 
            +
                    self.lazy = lazy
         | 
| 53 | 
            +
                    self.max_length = max_length
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    json_data = self.json_file_preprocess(data_path)
         | 
| 56 | 
            +
                    self.text_data = json_data
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    self.image_folder = image_folder
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    self.tokenizer = BUILDER.build(tokenizer)
         | 
| 61 | 
            +
                    if special_tokens is not None:
         | 
| 62 | 
            +
                        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    self.template_map_fn = template_map_fn
         | 
| 65 | 
            +
                    if isinstance(self.template_map_fn, dict) and self.lazy:
         | 
| 66 | 
            +
                        _type = self.template_map_fn['type']
         | 
| 67 | 
            +
                        del self.template_map_fn['type']
         | 
| 68 | 
            +
                        self.template_map_fn = _type(**self.template_map_fn)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    if extra_image_processor is not None:
         | 
| 71 | 
            +
                        self.extra_image_processor = BUILDER.build(extra_image_processor)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    self.repeats = repeats
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    self._system = ''
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    self.min_dynamic_patch = 1
         | 
| 78 | 
            +
                    self.max_dynamic_patch = 12
         | 
| 79 | 
            +
                    self.downsample_ratio = 0.5
         | 
| 80 | 
            +
                    self.image_size = 448
         | 
| 81 | 
            +
                    self.use_thumbnail = True
         | 
| 82 | 
            +
                    patch_size = 14
         | 
| 83 | 
            +
                    self.patch_size = patch_size
         | 
| 84 | 
            +
                    self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    self.transformer = T.Compose([
         | 
| 87 | 
            +
                        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
         | 
| 88 | 
            +
                        T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
         | 
| 89 | 
            +
                        T.ToTensor(),
         | 
| 90 | 
            +
                        T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
         | 
| 91 | 
            +
                    ])
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    if special_tokens is not None:
         | 
| 94 | 
            +
                        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    self.single_image_mode = single_image_mode
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def json_file_preprocess(self, data_path):
         | 
| 99 | 
            +
                    with open(data_path, 'r') as f:
         | 
| 100 | 
            +
                        json_data = json.load(f)
         | 
| 101 | 
            +
                    return json_data
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                @property
         | 
| 104 | 
            +
                def modality_length(self):
         | 
| 105 | 
            +
                    length_list = []
         | 
| 106 | 
            +
                    for data_dict in self.text_data:
         | 
| 107 | 
            +
                        if self.lazy:
         | 
| 108 | 
            +
                            cur_len = 100
         | 
| 109 | 
            +
                        else:
         | 
| 110 | 
            +
                            cur_len = len(data_dict['input_ids'])
         | 
| 111 | 
            +
                            if data_dict.get('image', None) is None:
         | 
| 112 | 
            +
                                cur_len = -cur_len
         | 
| 113 | 
            +
                        length_list.append(cur_len)
         | 
| 114 | 
            +
                    return length_list * self.repeats
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                def __len__(self):
         | 
| 117 | 
            +
                    return len(self.text_data) * self.repeats
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                def real_len(self):
         | 
| 120 | 
            +
                    return len(self.text_data)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def annToMask(self, mask_ann, h, w):
         | 
| 123 | 
            +
                    if isinstance(mask_ann, list):
         | 
| 124 | 
            +
                        rles = maskUtils.frPyObjects(mask_ann, h, w)
         | 
| 125 | 
            +
                        rle = maskUtils.merge(rles)
         | 
| 126 | 
            +
                    elif isinstance(mask_ann['counts'], list):
         | 
| 127 | 
            +
                        # uncompressed RLE
         | 
| 128 | 
            +
                        rle = maskUtils.frPyObjects(mask_ann, h, w)
         | 
| 129 | 
            +
                    else:
         | 
| 130 | 
            +
                        # rle
         | 
| 131 | 
            +
                        rle = mask_ann
         | 
| 132 | 
            +
                    mask = maskUtils.decode(rle)
         | 
| 133 | 
            +
                    return mask
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def decode_mask(self, object_masks, ori_height, ori_width):
         | 
| 136 | 
            +
                    binary_masks = []
         | 
| 137 | 
            +
                    for object_mask in object_masks:
         | 
| 138 | 
            +
                        binary_mask = self.annToMask(object_mask, ori_height, ori_width)
         | 
| 139 | 
            +
                        binary_masks.append(binary_mask)
         | 
| 140 | 
            +
                    if len(binary_masks) == 0:
         | 
| 141 | 
            +
                        return None
         | 
| 142 | 
            +
                    masks = np.stack(binary_masks, axis=0)
         | 
| 143 | 
            +
                    masks = torch.from_numpy(masks)
         | 
| 144 | 
            +
                    return masks
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def _process_conversation(self, converations, n_regions, region_pixels):
         | 
| 147 | 
            +
                    start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
         | 
| 148 | 
            +
                    for i in range(n_regions):
         | 
| 149 | 
            +
                        start_region_str = start_region_str + \
         | 
| 150 | 
            +
                                           f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
         | 
| 151 | 
            +
                        if i == n_regions - 1:
         | 
| 152 | 
            +
                            start_region_str = start_region_str + '.\n'
         | 
| 153 | 
            +
                        else:
         | 
| 154 | 
            +
                            start_region_str = start_region_str + ', '
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    for i, item in enumerate(converations):
         | 
| 157 | 
            +
                        item['value'] = item['value'].replace('<', '').replace('>', '')
         | 
| 158 | 
            +
                        if item['from'] == 'human':
         | 
| 159 | 
            +
                            item['value'] = item['value'] + self.LIMIT
         | 
| 160 | 
            +
                        # first conv process
         | 
| 161 | 
            +
                        if i == 0:
         | 
| 162 | 
            +
                            assert item['from'] == "human"
         | 
| 163 | 
            +
                            item['value'] =  start_region_str + item['value']
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                    messages = converations
         | 
| 166 | 
            +
                    input = ''
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    conversation = []
         | 
| 169 | 
            +
                    while messages and messages[0]['from'] == 'gpt':
         | 
| 170 | 
            +
                        # Skip the first one if it is from gpt
         | 
| 171 | 
            +
                        messages = messages[1:]
         | 
| 172 | 
            +
                    for msg in messages:
         | 
| 173 | 
            +
                        if msg['from'] == 'human':
         | 
| 174 | 
            +
                            if DEFAULT_IMAGE_TOKEN in msg['value']:
         | 
| 175 | 
            +
                                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
         | 
| 176 | 
            +
                                                                    '').strip()
         | 
| 177 | 
            +
                                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
         | 
| 178 | 
            +
                                msg['value'] = msg['value'].strip()
         | 
| 179 | 
            +
                            input += msg['value']
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                        elif msg['from'] == 'gpt':
         | 
| 182 | 
            +
                            conversation.append({'input': input, 'output': msg['value']})
         | 
| 183 | 
            +
                            input = ''
         | 
| 184 | 
            +
                        else:
         | 
| 185 | 
            +
                            raise NotImplementedError
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    return conversation
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                def _get_region_infos(self, masks):
         | 
| 190 | 
            +
                    # masks tensor, (n_obj, h, w)
         | 
| 191 | 
            +
                    masks = F.interpolate(
         | 
| 192 | 
            +
                        masks.unsqueeze(0),
         | 
| 193 | 
            +
                        size=(int(self.image_size // self.patch_size * self.downsample_ratio),
         | 
| 194 | 
            +
                              int(self.image_size // self.patch_size * self.downsample_ratio)),
         | 
| 195 | 
            +
                        mode='nearest').squeeze(0)
         | 
| 196 | 
            +
                    region_pixels = []
         | 
| 197 | 
            +
                    for mask in masks:
         | 
| 198 | 
            +
                        region_pixels.append(mask.bool().to(torch.int64).sum())
         | 
| 199 | 
            +
                    return masks, region_pixels
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                def dataset_map_fn(self, data_dict):
         | 
| 202 | 
            +
                    file_name = data_dict['file_name'] # image file name
         | 
| 203 | 
            +
                    conversations = data_dict['conversations']
         | 
| 204 | 
            +
                    masks = [anno["segmentation"] for anno in data_dict["annotation"]]
         | 
| 205 | 
            +
                    height = data_dict['height']
         | 
| 206 | 
            +
                    width = data_dict['width']
         | 
| 207 | 
            +
                    _ret = {}
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    _ret['image'] = file_name
         | 
| 210 | 
            +
                    _ret['height'] = height
         | 
| 211 | 
            +
                    _ret['width'] = width
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    masks = self.decode_mask(masks, height, width)
         | 
| 214 | 
            +
                    if masks is None:
         | 
| 215 | 
            +
                        return None
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    masks, region_pixels = self._get_region_infos(masks)
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                    conversations = self._process_conversation(conversations, len(masks), region_pixels)
         | 
| 220 | 
            +
                    _ret['conversation'] = conversations
         | 
| 221 | 
            +
                    _ret['prompt_masks'] = masks
         | 
| 222 | 
            +
                    return _ret
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                def replace_image_str(self, data_dict, image_str):
         | 
| 225 | 
            +
                    data_dict['conversation'][0]['input'] = \
         | 
| 226 | 
            +
                        data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
         | 
| 227 | 
            +
                    return data_dict
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                def __getitem__(self, index):
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    index = index % self.real_len()
         | 
| 232 | 
            +
                    data_dict = copy.deepcopy(self.text_data[index])
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    # parse datasets
         | 
| 235 | 
            +
                    result = self.dataset_map_fn(data_dict) # {'image', 'height', 'width', 'conversation', 'prompt_masks'}
         | 
| 236 | 
            +
                    if result is None or result['prompt_masks'] is None:
         | 
| 237 | 
            +
                        return self.__getitem__(0)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                    data_dict = result
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                    # process image
         | 
| 242 | 
            +
                    image_file = data_dict['image']
         | 
| 243 | 
            +
                    if isinstance(self.image_folder, list):
         | 
| 244 | 
            +
                        for image_folder in self.image_folder:
         | 
| 245 | 
            +
                            image_path = os.path.join(image_folder, image_file)
         | 
| 246 | 
            +
                            if os.path.exists(image_path):
         | 
| 247 | 
            +
                                image = Image.open(image_path).convert('RGB')
         | 
| 248 | 
            +
                                break
         | 
| 249 | 
            +
                    else:
         | 
| 250 | 
            +
                        image = Image.open(os.path.join(self.image_folder,
         | 
| 251 | 
            +
                                                        image_file)).convert('RGB')
         | 
| 252 | 
            +
                    ori_width, ori_height = image.size
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    if self.single_image_mode:
         | 
| 255 | 
            +
                        images = [image]
         | 
| 256 | 
            +
                    else:
         | 
| 257 | 
            +
                        images = dynamic_preprocess(image, self.min_dynamic_patch,
         | 
| 258 | 
            +
                                                    self.max_dynamic_patch,
         | 
| 259 | 
            +
                                                    self.image_size, self.use_thumbnail)
         | 
| 260 | 
            +
                    vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
         | 
| 261 | 
            +
                    data_dict['vp_overall_mask'] = vp_overall_mask
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    pixel_values = [self.transformer(image) for image in images]
         | 
| 264 | 
            +
                    pixel_values = torch.stack(pixel_values)
         | 
| 265 | 
            +
                    data_dict['pixel_values'] = pixel_values
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    num_image_tokens = pixel_values.shape[0] * self.patch_token
         | 
| 268 | 
            +
                    image_token_str = f'{self.IMG_START_TOKEN}' \
         | 
| 269 | 
            +
                                      f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
         | 
| 270 | 
            +
                                      f'{self.IMG_END_TOKEN}'
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    data_dict = self.replace_image_str(data_dict, image_token_str)
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                    result = self.template_map_fn(data_dict)
         | 
| 275 | 
            +
                    data_dict.update(result)
         | 
| 276 | 
            +
                    result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
         | 
| 277 | 
            +
                                                  with_image_token=True)
         | 
| 278 | 
            +
                    data_dict.update(result)
         | 
| 279 | 
            +
                    # process mask
         | 
| 280 | 
            +
                    # data_dict['prompt_masks'] = data_dict['prompt_masks']
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    if data_dict['prompt_masks'] is None:
         | 
| 283 | 
            +
                        return self.__getitem__(0)
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                    return data_dict
         | 
| 286 | 
            +
             | 
| 287 | 
            +
             | 
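For token bookkeeping in __getitem__: each 448x448 tile contributes patch_token = (448 // 14)**2 * 0.5**2 = 256 IMG_CONTEXT tokens, so the <img>...</img> placeholder grows linearly with the number of tiles returned by dynamic_preprocess. A quick check under those assumed defaults:

    IMG_START, IMG_CONTEXT, IMG_END = '<img>', '<IMG_CONTEXT>', '</img>'

    patch_token = int((448 // 14) ** 2 * 0.5 ** 2)  # 256 tokens per tile
    n_tiles = 4                                     # e.g. dynamic_preprocess output
    num_image_tokens = n_tiles * patch_token
    image_token_str = f'{IMG_START}{IMG_CONTEXT * num_image_tokens}{IMG_END}'
    print(patch_token, num_image_tokens)            # 256 1024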
| 288 | 
            +
            DETAILED_QUESTIONS =  [
         | 
| 289 | 
            +
                'Can you provide me with a detailed description of the region in the picture marked by <region>?',
         | 
| 290 | 
            +
                "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
         | 
| 291 | 
            +
                'What can you tell me about the region indicated by <region> in the image?',
         | 
| 292 | 
            +
                "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
         | 
| 293 | 
            +
                'Could you describe the region shown as <region> in the picture in great detail?',
         | 
| 294 | 
            +
                'What details can you give me about the region outlined by <region> in the photo?',
         | 
| 295 | 
            +
                'Please provide me with a comprehensive description of the region marked with <region> in the image.',
         | 
| 296 | 
            +
                'Can you give me a detailed account of the region labeled as <region> in the picture?',
         | 
| 297 | 
            +
                "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
         | 
| 298 | 
            +
                'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
         | 
| 299 | 
            +
                'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
         | 
| 300 | 
            +
                "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
         | 
| 301 | 
            +
                'What can you tell me about the region indicated by <region> in the image, exactly?',
         | 
| 302 | 
            +
                "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
         | 
| 303 | 
            +
                'Could you describe the region shown as <region> in the picture in great detail, please?',
         | 
| 304 | 
            +
                'What details can you give me about the region outlined by <region> in the photo, please?',
         | 
| 305 | 
            +
                'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
         | 
| 306 | 
            +
                'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
         | 
| 307 | 
            +
                "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
         | 
| 308 | 
            +
                'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
         | 
| 309 | 
            +
                'Please describe the region <region> in the image in detail.',
         | 
| 310 | 
            +
                'Can you offer a thorough analysis of the region <region> in the image?',
         | 
| 311 | 
            +
                'Could you elaborate on the region highlighted by <region> in the picture provided?',
         | 
| 312 | 
            +
                'Please share more information about the zone emphasized with <region> in the photo.',
         | 
| 313 | 
            +
                'What insights can you give about the area denoted by <region> in the image presented?',
         | 
| 314 | 
            +
                'Can you share a comprehensive rundown of the region denoted by <region> in the presented image?',
         | 
| 315 | 
            +
                "I'd like to know more about the region highlighted by <region> in the picture provided.",
         | 
| 316 | 
            +
                'Work through the important details of the area <region> in the image.',
         | 
| 317 | 
            +
                'Illustrate the area represented by <region> through a descriptive explanation.',
         | 
| 318 | 
            +
                'Examine the region <region> closely and share its details.'
         | 
| 319 | 
            +
            ]
         | 
| 320 | 
            +
             | 
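At train time one paraphrase is drawn at random and the <region> placeholder is substituted with the region index, e.g.:

    import random

    # Two-template subset of DETAILED_QUESTIONS, for illustration.
    templates = [
        'What can you tell me about the region indicated by <region> in the image?',
        'Please describe the region <region> in the image in detail.',
    ]
    question = random.choice(templates).strip().replace('<region>', 'region1')
    print(question)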
| 321 | 
            +
            class OspreyDescriptionDataset(OspreyDataset):
         | 
| 322 | 
            +
                os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         | 
| 323 | 
            +
                IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
         | 
| 324 | 
            +
                IMG_START_TOKEN = '<img>'
         | 
| 325 | 
            +
                IMG_END_TOKEN = '</img>'
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                VP_START_TOKEN = '<vp>'
         | 
| 328 | 
            +
                VP_END_TOKEN = '</vp>'
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                LIMIT = ''
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                IMAGENET_MEAN = (0.485, 0.456, 0.406)
         | 
| 333 | 
            +
                IMAGENET_STD = (0.229, 0.224, 0.225)
         | 
| 334 | 
            +
                def __init__(self,
         | 
| 335 | 
            +
                             image_folder,
         | 
| 336 | 
            +
                             data_path=None,
         | 
| 337 | 
            +
                             tokenizer=None,
         | 
| 338 | 
            +
                             max_length=8196,
         | 
| 339 | 
            +
                             special_tokens=None,
         | 
| 340 | 
            +
                             template_map_fn=None,
         | 
| 341 | 
            +
                             extra_image_processor=None,
         | 
| 342 | 
            +
                             lazy=True,
         | 
| 343 | 
            +
                             repeats=1,
         | 
| 344 | 
            +
                             single_image_mode=False,
         | 
| 345 | 
            +
                ):
         | 
| 346 | 
            +
                    super(OspreyDescriptionDataset, self).__init__(
         | 
| 347 | 
            +
                        image_folder=image_folder,
         | 
| 348 | 
            +
                        data_path=data_path,
         | 
| 349 | 
            +
                        tokenizer=tokenizer,
         | 
| 350 | 
            +
                        max_length=max_length,
         | 
| 351 | 
            +
                        special_tokens=special_tokens,
         | 
| 352 | 
            +
                        template_map_fn=template_map_fn,
         | 
| 353 | 
            +
                        extra_image_processor=extra_image_processor,
         | 
| 354 | 
            +
                        lazy=lazy,
         | 
| 355 | 
            +
                        repeats=repeats,
         | 
| 356 | 
            +
                        single_image_mode=single_image_mode,
         | 
| 357 | 
            +
                    )
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                def dataset_map_fn(self, data_dict):
         | 
| 360 | 
            +
                    file_name = data_dict['file_name'] # image file name
         | 
| 361 | 
            +
                    descriptions = data_dict['description']
         | 
| 362 | 
            +
                    masks = [anno["segmentation"] for anno in data_dict["annotation"]]
         | 
| 363 | 
            +
                    height = data_dict['height']
         | 
| 364 | 
            +
                    width = data_dict['width']
         | 
| 365 | 
            +
                    _ret = {}
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    _ret['image'] = file_name
         | 
| 368 | 
            +
                    _ret['height'] = height
         | 
| 369 | 
            +
                    _ret['width'] = width
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    masks = self.decode_mask(masks, height, width)
         | 
| 372 | 
            +
                    if masks is None:
         | 
| 373 | 
            +
                        return None
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                    masks, region_pixels = self._get_region_infos(masks)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    conversations = self._process_conversation(descriptions, len(masks), region_pixels)
         | 
| 378 | 
            +
                    _ret['conversation'] = conversations
         | 
| 379 | 
            +
                    _ret['prompt_masks'] = masks
         | 
| 380 | 
            +
                    return _ret
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                def _process_conversation(self, descriptions, n_regions, region_pixels):
         | 
| 383 | 
            +
                    start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
         | 
| 384 | 
            +
                    for i in range(n_regions):
         | 
| 385 | 
            +
                        start_region_str = start_region_str + \
         | 
| 386 | 
            +
                                           f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
         | 
| 387 | 
            +
                        if i == n_regions - 1:
         | 
| 388 | 
            +
                            start_region_str = start_region_str + '.\n'
         | 
| 389 | 
            +
                        else:
         | 
| 390 | 
            +
                            start_region_str = start_region_str + ', '
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                conversations = []
         | 
| 393 | 
            +
                    for i, item in enumerate(descriptions):
         | 
| 394 | 
            +
                        question = random.choice(DETAILED_QUESTIONS).strip().replace('<region>', f"region{i+1}") + self.LIMIT
         | 
| 395 | 
            +
                        answer = item.replace('<', '').replace('>', '')
         | 
| 396 | 
            +
                        # first conv process
         | 
| 397 | 
            +
                        if i == 0:
         | 
| 398 | 
            +
                            question = start_region_str + question
         | 
| 399 | 
            +
                    conversations.append({'from': 'human', 'value': question})
         | 
| 400 | 
            +
                    conversations.append({'from': 'gpt', 'value': answer})
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                messages = conversations
         | 
| 403 | 
            +
                    input = ''
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                    conversation = []
         | 
| 406 | 
            +
                    while messages and messages[0]['from'] == 'gpt':
         | 
| 407 | 
            +
                        # Skip the first one if it is from gpt
         | 
| 408 | 
            +
                        messages = messages[1:]
         | 
| 409 | 
            +
                    for msg in messages:
         | 
| 410 | 
            +
                        if msg['from'] == 'human':
         | 
| 411 | 
            +
                            if DEFAULT_IMAGE_TOKEN in msg['value']:
         | 
| 412 | 
            +
                                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
         | 
| 413 | 
            +
                                                                    '').strip()
         | 
| 414 | 
            +
                                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
         | 
| 415 | 
            +
                                msg['value'] = msg['value'].strip()
         | 
| 416 | 
            +
                            input += msg['value']
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                        elif msg['from'] == 'gpt':
         | 
| 419 | 
            +
                            conversation.append({'input': input, 'output': msg['value']})
         | 
| 420 | 
            +
                            input = ''
         | 
| 421 | 
            +
                        else:
         | 
| 422 | 
            +
                            raise NotImplementedError
         | 
| 423 | 
            +
                    return conversation
         | 
| 424 | 
            +
             | 
| 425 | 
            +
             | 
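To make the visual-prompt layout concrete: for two regions whose downsampled masks cover 3 and 1 feature-grid cells, start_region_str interleaves region names with <vp>-delimited runs of IMG_CONTEXT tokens (cell counts here are illustrative):

    VP_START, VP_END, IMG_CONTEXT = '<vp>', '</vp>', '<IMG_CONTEXT>'

    region_pixels = [3, 1]  # grid cells per region (illustrative)
    s = '<image> There are {} part regions in the picture: '.format(len(region_pixels))
    for i, n in enumerate(region_pixels):
        s += f'region{i + 1}' + VP_START + IMG_CONTEXT * n + VP_END
        s += '.\n' if i == len(region_pixels) - 1 else ', '
    print(s)  # region1 carries 3 IMG_CONTEXT tokens, region2 carries 1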
| 426 | 
            +
            class OspreyShortDescriptionDataset(OspreyDataset):
         | 
| 427 | 
            +
                os.environ['TOKENIZERS_PARALLELISM'] = 'true'
         | 
| 428 | 
            +
                IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
         | 
| 429 | 
            +
                IMG_START_TOKEN = '<img>'
         | 
| 430 | 
            +
                IMG_END_TOKEN = '</img>'
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                VP_START_TOKEN = '<vp>'
         | 
| 433 | 
            +
                VP_END_TOKEN = '</vp>'
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                LIMIT = ' Answer the question using a single word or phrase.'
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                IMAGENET_MEAN = (0.485, 0.456, 0.406)
         | 
| 438 | 
            +
                IMAGENET_STD = (0.229, 0.224, 0.225)
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                def __init__(self,
         | 
| 441 | 
            +
                             image_folder,
         | 
| 442 | 
            +
                             data_path=None,
         | 
| 443 | 
            +
                             tokenizer=None,
         | 
| 444 | 
            +
                             max_length=8196,
         | 
| 445 | 
            +
                             special_tokens=None,
         | 
| 446 | 
            +
                             template_map_fn=None,
         | 
| 447 | 
            +
                             extra_image_processor=None,
         | 
| 448 | 
            +
                             lazy=True,
         | 
| 449 | 
            +
                             repeats=1,
         | 
| 450 | 
            +
                             single_image_mode=False,
         | 
| 451 | 
            +
                             ):
         | 
| 452 | 
            +
                    super(OspreyShortDescriptionDataset, self).__init__(
         | 
| 453 | 
            +
                        image_folder=image_folder,
         | 
| 454 | 
            +
                        data_path=data_path,
         | 
| 455 | 
            +
                        tokenizer=tokenizer,
         | 
| 456 | 
            +
                        max_length=max_length,
         | 
| 457 | 
            +
                        special_tokens=special_tokens,
         | 
| 458 | 
            +
                        template_map_fn=template_map_fn,
         | 
| 459 | 
            +
                        extra_image_processor=extra_image_processor,
         | 
| 460 | 
            +
                        lazy=lazy,
         | 
| 461 | 
            +
                        repeats=repeats,
         | 
| 462 | 
            +
                        single_image_mode=single_image_mode,
         | 
| 463 | 
            +
                    )
         | 
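A hedged instantiation sketch for the three Osprey dataset variants above; the paths, tokenizer, and template objects are placeholders rather than values shipped with the repo (real configs build these through xtuner's BUILDER from dict specs):

    # Illustrative only; every argument value below is a placeholder.
    # dataset = OspreyDescriptionDataset(
    #     image_folder='/path/to/images',                # placeholder
    #     data_path='/path/to/osprey_annotations.json',  # placeholder
    #     tokenizer=tokenizer_cfg,                       # hypothetical tokenizer spec
    #     template_map_fn=template_map_fn_cfg,           # hypothetical map-fn spec
    #     special_tokens=['<IMG_CONTEXT>', '<img>', '</img>', '<vp>', '</vp>'],
    #     max_length=8196,
    # )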
    	
        projects/llava_sam2/datasets/ReSAM2_Dataset.py
    ADDED
    
    | @@ -0,0 +1,489 @@ | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from datasets import Dataset as HFDataset
         | 
| 5 | 
            +
            from datasets import DatasetDict, load_from_disk
         | 
| 6 | 
            +
            from mmengine import print_log
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
            from torch.utils.data import Dataset
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from xtuner.registry import BUILDER
         | 
| 12 | 
            +
            from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
         | 
| 13 | 
            +
            import copy
         | 
| 14 | 
            +
            from .encode_fn import video_lisa_encode_fn
         | 
| 15 | 
            +
            import json
         | 
| 16 | 
            +
            import random
         | 
| 17 | 
            +
            import pycocotools.mask as maskUtils
         | 
| 18 | 
            +
            import cv2
         | 
| 19 | 
            +
            import torchvision.transforms as T
         | 
| 20 | 
            +
            from torchvision.transforms.functional import InterpolationMode
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            SEG_QUESTIONS = [
         | 
| 23 | 
            +
                "Please segment the object according to the description: {class_name}",
         | 
| 24 | 
            +
            ]
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            SEG_QUESTIONS_SHORT = [
         | 
| 27 | 
            +
                "Can you segment the {class_name} in this image?",
         | 
| 28 | 
            +
                "Please segment {class_name} in this image.",
         | 
| 29 | 
            +
                "What is {class_name} in this image? Please respond with segmentation mask.",
         | 
| 30 | 
            +
                "What is {class_name} in this image? Please output segmentation mask.",
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                "Can you segment the {class_name} in this image",
         | 
| 33 | 
            +
                "Please segment {class_name} in this image",
         | 
| 34 | 
            +
                "What is {class_name} in this image? Please respond with segmentation mask",
         | 
| 35 | 
            +
                "What is {class_name} in this image? Please output segmentation mask",
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                "Could you provide a segmentation mask for the {class_name} in this image?",
         | 
| 38 | 
            +
                "Please identify and segment the {class_name} in this image.",
         | 
| 39 | 
            +
                "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
         | 
| 40 | 
            +
                "Can you highlight the {class_name} in this image with a segmentation mask?",
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                "Could you provide a segmentation mask for the {class_name} in this image",
         | 
| 43 | 
            +
                "Please identify and segment the {class_name} in this image",
         | 
| 44 | 
            +
                "Where is the {class_name} in this picture? Please respond with a segmentation mask",
         | 
| 45 | 
            +
                "Can you highlight the {class_name} in this image with a segmentation mask",
         | 
| 46 | 
            +
            ]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            ANSWER_LIST = [
         | 
| 49 | 
            +
                "It is [SEG].",
         | 
| 50 | 
            +
                "Sure, [SEG].",
         | 
| 51 | 
            +
                "Sure, it is [SEG].",
         | 
| 52 | 
            +
                "Sure, the segmentation result is [SEG].",
         | 
| 53 | 
            +
                "[SEG].",
         | 
| 54 | 
            +
            ]
         | 
| 55 | 
            +
             | 
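The [SEG] token in ANSWER_LIST serves as the placeholder the model later decodes into a segmentation mask; a question/answer pair is assembled by sampling one template from each list, mirroring prepare_text below:

    import random

    SEG_QUESTIONS = ['Please segment the object according to the description: {class_name}']
    ANSWER_LIST = ['It is [SEG].', 'Sure, [SEG].', '[SEG].']

    exp = 'a black dog running on the beach'  # illustrative expression
    question = random.choice(SEG_QUESTIONS).format(class_name=exp)
    answer = random.choice(ANSWER_LIST)
    print(question)
    print(answer)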
| 56 | 
            +
            class VideoSAM2Dataset(Dataset):
         | 
| 57 | 
            +
                IMAGENET_MEAN = (0.485, 0.456, 0.406)
         | 
| 58 | 
            +
                IMAGENET_STD = (0.229, 0.224, 0.225)
         | 
| 59 | 
            +
                IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
         | 
| 60 | 
            +
                IMG_START_TOKEN = '<img>'
         | 
| 61 | 
            +
                IMG_END_TOKEN = '</img>'
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
         | 
| 64 | 
            +
                FAST_IMG_START_TOKEN = '<fast_img>'
         | 
| 65 | 
            +
                FAST_IMG_END_TOKEN = '</fast_img>'
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def __init__(self,
         | 
| 68 | 
            +
                             sam2_folder,
         | 
| 69 | 
            +
                             expression_file,
         | 
| 70 | 
            +
                             extra_image_processor=None,
         | 
| 71 | 
            +
                             tokenizer=None,
         | 
| 72 | 
            +
                             select_number=5,
         | 
| 73 | 
            +
                             sampled_frames=5,
         | 
| 74 | 
            +
                             offline_processed_text_folder=None,
         | 
| 75 | 
            +
                             template_map_fn=None,
         | 
| 76 | 
            +
                             max_length=8196,
         | 
| 77 | 
            +
                             lazy=True,
         | 
| 78 | 
            +
                             repeats=1,
         | 
| 79 | 
            +
                             special_tokens=None,
         | 
| 80 | 
            +
                             use_fast=False,
         | 
| 81 | 
            +
                             n_fast_images=50,
         | 
| 82 | 
            +
                             fast_pool_size=4,
         | 
| 83 | 
            +
                             mode='long',
         | 
| 84 | 
            +
                             frame_contiguous_sample=False,
         | 
| 85 | 
            +
                ):
         | 
| 86 | 
            +
                    assert mode in ['long', 'long_short', 'short']
         | 
| 87 | 
            +
                    self.mode = mode
         | 
| 88 | 
            +
                    self.cur_mode = mode
         | 
| 89 | 
            +
                    assert lazy is True
         | 
| 90 | 
            +
                    self.tokenizer = BUILDER.build(tokenizer)
         | 
| 91 | 
            +
                    self.select_number = select_number
         | 
| 92 | 
            +
                    self.sampled_frames = sampled_frames
         | 
| 93 | 
            +
                    assert offline_processed_text_folder or (expression_file and tokenizer)
         | 
| 94 | 
            +
                    self.lazy = lazy
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    self.max_length = max_length
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    self.template_map_fn = template_map_fn
         | 
| 99 | 
            +
                    if isinstance(self.template_map_fn, dict) and self.lazy:
         | 
| 100 | 
            +
                        _type = self.template_map_fn['type']
         | 
| 101 | 
            +
                        del self.template_map_fn['type']
         | 
| 102 | 
            +
                        self.template_map_fn = _type(**self.template_map_fn)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    if offline_processed_text_folder and expression_file:
         | 
| 105 | 
            +
                        print_log(
         | 
| 106 | 
            +
                            'Both `offline_processed_text_folder` and '
         | 
| 107 | 
            +
                        '`expression_file` are set, and we load the dataset from '
         | 
| 108 | 
            +
                            '`offline_processed_text_folder` '
         | 
| 109 | 
            +
                            f'({offline_processed_text_folder})',
         | 
| 110 | 
            +
                            logger='current',
         | 
| 111 | 
            +
                            level=logging.WARNING)
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    if offline_processed_text_folder is not None:
         | 
| 114 | 
            +
                        raise NotImplementedError
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        video_ids, anno_dict = self.json_file_preprocess(expression_file)
         | 
| 117 | 
            +
                        if self.lazy:
         | 
| 118 | 
            +
                            self.video_ids = video_ids
         | 
| 119 | 
            +
                            self.anno_dict = anno_dict
         | 
| 120 | 
            +
                        else:
         | 
| 121 | 
            +
                            raise NotImplementedError
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    self.sam2_folder = sam2_folder
         | 
| 124 | 
            +
                    if extra_image_processor is not None:
         | 
| 125 | 
            +
                        self.extra_image_processor = BUILDER.build(extra_image_processor)
         | 
| 126 | 
            +
                    self.down_ratio = 1
         | 
| 127 | 
            +
                    self.repeats = repeats
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    self._system = ''
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    self.downsample_ratio = 0.5
         | 
| 132 | 
            +
                    self.image_size = 448
         | 
| 133 | 
            +
                    patch_size = 14
         | 
| 134 | 
            +
                    self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    self.transformer = T.Compose([
         | 
| 137 | 
            +
                        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
         | 
| 138 | 
            +
                        T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
         | 
| 139 | 
            +
                        T.ToTensor(),
         | 
| 140 | 
            +
                        T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
         | 
| 141 | 
            +
                    ])
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    if special_tokens is not None:
         | 
| 144 | 
            +
                        self.tokenizer.add_tokens(special_tokens, special_tokens=True)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    self.use_fast = use_fast
         | 
| 147 | 
            +
                    self.n_fast_images = n_fast_images
         | 
| 148 | 
            +
                    self.fast_pool_size = fast_pool_size
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    self.frame_contiguous_sample = frame_contiguous_sample
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                    # for visualization debug
         | 
| 153 | 
            +
                    self.save_folder = './work_dirs/video_debug/'
         | 
| 154 | 
            +
                    self.cur_number = 0
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    print("Video res dataset (ref-sam2), include {} items.".format(len(self.video_ids)))
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                def __len__(self):
         | 
| 159 | 
            +
                    return len(self.video_ids) * self.repeats
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                @property
         | 
| 162 | 
            +
                def modality_length(self):
         | 
| 163 | 
            +
                    length_list = []
         | 
| 164 | 
            +
                    for data_dict in self.video_ids:
         | 
| 165 | 
            +
                        cur_len = 20000
         | 
| 166 | 
            +
                        length_list.append(cur_len)
         | 
| 167 | 
            +
                    return length_list
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                def real_len(self):
         | 
| 170 | 
            +
                    return len(self.video_ids)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                def json_file_preprocess(self, expression_file):
         | 
| 173 | 
            +
                    # prepare expression annotation files
         | 
| 174 | 
            +
                    with open(expression_file, 'r') as f:
         | 
| 175 | 
            +
                        expression_datas = json.load(f)
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                    video_ids = list(expression_datas.keys())
         | 
| 178 | 
            +
                    return video_ids, expression_datas
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def dataset_map_fn(self, objects_expression_infos, n_frames, n_fast_frames=0):
         | 
| 181 | 
            +
                    # prepare text
         | 
| 182 | 
            +
                    if self.mode == 'long':
         | 
| 183 | 
            +
                        expressions = [object_info['formated'] for object_info in objects_expression_infos]
         | 
| 184 | 
            +
                        self.cur_mode = self.mode
         | 
| 185 | 
            +
                    elif self.mode == 'short':
         | 
| 186 | 
            +
                    expressions = [random.choice(object_info['short_caps']) for object_info in objects_expression_infos]
         | 
| 187 | 
            +
                        self.cur_mode = self.mode
         | 
| 188 | 
            +
                    else:
         | 
| 189 | 
            +
                        if random.random() < 0.5:
         | 
| 190 | 
            +
                            expressions = [object_info['formated'] for object_info in objects_expression_infos]
         | 
| 191 | 
            +
                            self.cur_mode = 'long'
         | 
| 192 | 
            +
                        else:
         | 
| 193 | 
            +
                        expressions = [random.choice(object_info['short_caps']) for
         | 
| 194 | 
            +
                                       object_info in objects_expression_infos]
         | 
| 195 | 
            +
                            self.cur_mode = 'short'
         | 
| 196 | 
            +
                    text_dict = self.prepare_text(n_frames, expressions, num_image_tokens=self.patch_token,
         | 
| 197 | 
            +
                                                  n_fast_frames=n_fast_frames)
         | 
| 198 | 
            +
                    ret = {'conversation': text_dict['conversation']}
         | 
| 199 | 
            +
                    return ret
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_frames=0):
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                    if self.use_fast:
         | 
| 204 | 
            +
                        fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
         | 
| 205 | 
            +
                                      f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_frames * self.fast_pool_size * self.fast_pool_size}' \
         | 
| 206 | 
            +
                                      f'{self.FAST_IMG_END_TOKEN}' + '\n'
         | 
| 207 | 
            +
                    else:
         | 
| 208 | 
            +
                        fast_frame_token_str = ''
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    frame_token_str = f'{self.IMG_START_TOKEN}' \
         | 
| 211 | 
            +
                                      f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
         | 
| 212 | 
            +
                                      f'{self.IMG_END_TOKEN}'
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    questions = []
         | 
| 215 | 
            +
                    answers = []
         | 
| 216 | 
            +
                    for i, exp in enumerate(expressions):
         | 
| 217 | 
            +
                        if self.cur_mode == 'short':
         | 
| 218 | 
            +
                            question_template = random.choice(SEG_QUESTIONS_SHORT)
         | 
| 219 | 
            +
                            exp = exp.replace("A ", '')
         | 
| 220 | 
            +
                        else:
         | 
| 221 | 
            +
                            question_template = random.choice(SEG_QUESTIONS)
         | 
| 222 | 
            +
                        questions.append(question_template.format(class_name=exp))
         | 
| 223 | 
            +
                        answers.append(random.choice(ANSWER_LIST))
         | 
| 224 | 
            +
                    qa_list = []
         | 
| 225 | 
            +
                    for i, (question, answer) in enumerate(zip(questions, answers)):
         | 
| 226 | 
            +
                        if i == 0:
         | 
| 227 | 
            +
                            frame_tokens = frame_token_str + '\n'
         | 
| 228 | 
            +
                            # frame_tokens = '=' + ' '
         | 
| 229 | 
            +
                            frame_tokens = frame_tokens * n_frames
         | 
| 230 | 
            +
                            frame_tokens = frame_tokens.strip()
         | 
| 231 | 
            +
                            frame_tokens = fast_frame_token_str + frame_tokens
         | 
| 232 | 
            +
                            qa_list.append(
         | 
| 233 | 
            +
                                {'from': 'human', 'value': frame_tokens + question}
         | 
| 234 | 
            +
                            )
         | 
| 235 | 
            +
                        else:
         | 
| 236 | 
            +
                            qa_list.append(
         | 
| 237 | 
            +
                                {'from': 'human', 'value': question}
         | 
| 238 | 
            +
                            )
         | 
| 239 | 
            +
                        qa_list.append(
         | 
| 240 | 
            +
                            {'from': 'gpt', 'value': answer}
         | 
| 241 | 
            +
                        )
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    input = ''
         | 
| 244 | 
            +
                    conversation = []
         | 
| 245 | 
            +
                    for msg in qa_list:
         | 
| 246 | 
            +
                        if msg['from'] == 'human':
         | 
| 247 | 
            +
                            input += msg['value']
         | 
| 248 | 
            +
                        elif msg['from'] == 'gpt':
         | 
| 249 | 
            +
                            conversation.append({'input': input, 'output': msg['value']})
         | 
| 250 | 
            +
                            input = ''
         | 
| 251 | 
            +
                        else:
         | 
| 252 | 
            +
                            raise NotImplementedError
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                    # add system information
         | 
| 255 | 
            +
                    conversation[0].update({'system': self._system})
         | 
| 256 | 
            +
                    return {'conversation': conversation}
         | 
| 257 | 
            +
             | 
| 258 | 
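Token budget implied by the defaults above (n_fast_images=50, fast_pool_size=4, patch_token=256, sampled_frames=5): the fast branch packs all frames into a single <fast_img> span of 50 * 4 * 4 = 800 tokens, while each slow frame gets its own <img> span of 256 tokens:

    n_fast_images, fast_pool_size = 50, 4  # defaults from __init__ above
    patch_token, sampled_frames = 256, 5

    fast_tokens = n_fast_images * fast_pool_size * fast_pool_size  # 800
    slow_tokens = sampled_frames * patch_token                     # 1280
    print(fast_tokens, slow_tokens, fast_tokens + slow_tokens)     # 800 1280 2080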
            +
                def __getitem__(self, index):
         | 
| 259 | 
            +
                    index = index % self.real_len()
         | 
| 260 | 
            +
                    video_id = self.video_ids[index]
         | 
| 261 | 
            +
                    expression_dict = self.anno_dict[video_id]
         | 
| 262 | 
            +
                    object_ids = list(expression_dict['objects'].keys())
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    video_path = os.path.join(self.sam2_folder, expression_dict['video_path'])
         | 
| 265 | 
            +
                    anno_path = os.path.join(self.sam2_folder, expression_dict['anno_path'])
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    video_frames = get_video_frames(video_path)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    if self.use_fast:
         | 
| 270 | 
            +
                        # sample fast branch
         | 
| 271 | 
            +
                        fast_interval = len(video_frames) / (self.n_fast_images + 1e-4)
         | 
| 272 | 
            +
                        sampled_fast_frame_idxs = [min(int(i * fast_interval), len(video_frames) - 1) for i in range(self.n_fast_images)]
         | 
| 273 | 
            +
                        fast_video_frames = [video_frames[_idx] for _idx in sampled_fast_frame_idxs]
         | 
| 274 | 
            +
                    else:
         | 
| 275 | 
            +
                        fast_video_frames = None
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                video_frames = video_frames[::4]  # keep every 4th frame to stay aligned with the masklet annotations
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    # mask annotation
         | 
| 280 | 
            +
                    with open(anno_path, 'r') as f:
         | 
| 281 | 
            +
                        mask_data = json.load(f)
         | 
| 282 | 
            +
                    masklents = decode_masklet(mask_data['masklet'])
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    n_frames = len(masklents)
         | 
| 285 | 
            +
                    n_objects = len(object_ids)
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    # sample object
         | 
| 288 | 
            +
                    if n_objects > self.select_number:
         | 
| 289 | 
            +
                    selected_indexes = np.random.choice(n_objects, self.select_number, replace=False)
         | 
| 290 | 
            +
                    else:
         | 
| 291 | 
            +
                        selected_indexes = np.random.choice(n_objects, self.select_number, replace=True)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    selected_object_ids = [object_ids[_idx] for _idx in selected_indexes]
         | 
| 294 | 
            +
                    objects_expression_infos = [expression_dict['objects'][_idx] for _idx in selected_object_ids]
         | 
| 295 | 
            +
                    _masklents = []
         | 
| 296 | 
            +
                    for _mask in masklents:
         | 
| 297 | 
            +
                        _mask_selected = []
         | 
| 298 | 
            +
                        for _idx in selected_object_ids:
         | 
| 299 | 
            +
                            _mask_selected.append(_mask[:, :, int(_idx)])
         | 
| 300 | 
            +
                        _mask_selected = np.stack(_mask_selected, axis=2)
         | 
| 301 | 
            +
                        _masklents.append(_mask_selected)
         | 
| 302 | 
            +
                    masklents = _masklents
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    # sample video frames
         | 
| 305 | 
            +
                # prepare images: randomly select k frames
         | 
| 306 | 
            +
                    if n_frames > self.sampled_frames + 1:
         | 
| 307 | 
            +
                        if self.frame_contiguous_sample and random.random() < 0.5:
         | 
| 308 | 
            +
                            # do contiguous sample
         | 
| 309 | 
            +
                            selected_start_frame = np.random.choice(n_frames - self.sampled_frames, 1, replace=False)
         | 
| 310 | 
            +
                            selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(self.sampled_frames)]
         | 
| 311 | 
            +
                        else:
         | 
| 312 | 
            +
                            selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=False)
         | 
| 313 | 
            +
                    else:
         | 
| 314 | 
            +
                        selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=True)
         | 
| 315 | 
            +
                    selected_frame_indexes.sort()
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    video_frames = [video_frames[_idx] for _idx in selected_frame_indexes]
         | 
| 318 | 
            +
                    masklents = [masklents[_idx] for _idx in selected_frame_indexes]
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    data_dict = self.dataset_map_fn(objects_expression_infos, len(video_frames), n_fast_frames=self.n_fast_images)
         | 
| 321 | 
            +
                    result = self.template_map_fn(data_dict)
         | 
| 322 | 
            +
                    data_dict.update(result)
         | 
| 323 | 
            +
                    result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
         | 
| 324 | 
            +
                    data_dict.update(result)
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                    pixel_values = []
         | 
| 327 | 
            +
                    extra_pixel_values = []
         | 
| 328 | 
            +
                    for frame in video_frames:
         | 
| 329 | 
            +
                    frame = frame[:, :, ::-1]  # reverse channel order (BGR -> RGB)
         | 
| 330 | 
            +
                        frame_image = Image.fromarray(frame).convert('RGB')
         | 
| 331 | 
            +
                        ori_width, ori_height = frame_image.size
         | 
| 332 | 
            +
                        if self.extra_image_processor is not None:
         | 
| 333 | 
            +
                            g_image = np.array(frame_image)  # for grounding
         | 
| 334 | 
            +
                            g_image = self.extra_image_processor.apply_image(g_image)
         | 
| 335 | 
            +
                            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
         | 
| 336 | 
            +
                            extra_pixel_values.append(g_pixel_values)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                        frame_image = self.transformer(frame_image)
         | 
| 339 | 
            +
                        pixel_values.append(frame_image)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                    pixel_values = torch.stack(pixel_values, dim=0)  # (n_f, 3, h, w)
         | 
| 342 | 
            +
                    data_dict['pixel_values'] = pixel_values
         | 
| 343 | 
            +
                    if self.extra_image_processor is not None:
         | 
| 344 | 
            +
                        data_dict['g_pixel_values'] = extra_pixel_values
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    # for fast branch
         | 
| 347 | 
            +
                    if self.use_fast:
         | 
| 348 | 
            +
                        fast_pixel_values = []
         | 
| 349 | 
            +
                        for frame_image in fast_video_frames:
         | 
| 350 | 
            +
                        frame = frame_image[:, :, ::-1]  # reverse channel order (BGR -> RGB)
         | 
| 351 | 
            +
                            frame_image = Image.fromarray(frame).convert('RGB')
         | 
| 352 | 
            +
                            ori_width, ori_height = frame_image.size
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                            frame_image = self.transformer(frame_image)
         | 
| 355 | 
            +
                            fast_pixel_values.append(frame_image)
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                        fast_pixel_values = torch.stack(fast_pixel_values, dim=0)  # (n_f, 3, h, w)
         | 
| 358 | 
            +
                        data_dict['fast_pixel_values'] = fast_pixel_values
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    # process and get masks
         | 
| 361 | 
            +
                    masklents = np.stack(masklents, axis=0)  # (n_frames, h, w, n_obj)
         | 
| 362 | 
            +
                    masklents = torch.from_numpy(masklents).permute(3, 0, 1, 2)
         | 
| 363 | 
            +
                    masklents = masklents.flatten(0, 1)
         | 
| 364 | 
            +
                    # print('sam2-mask_shape:', masklents.shape)
         | 
| 365 | 
            +
                    # print('sam2-pixel_values:', data_dict['pixel_values'].shape)
         | 
| 366 | 
            +
                    # print('sam2-g_pixel_values:', len(data_dict['g_pixel_values']), ', ', data_dict['g_pixel_values'][0].shape)
         | 
| 367 | 
            +
                    data_dict['masks'] = masklents
         | 
| 368 | 
            +
                    data_dict['type'] = 'video'
         | 
| 369 | 
            +
                    return data_dict
         | 

    def visualization_debug(self, data_dict):
        save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
        if not os.path.exists(save_folder):
            os.mkdir(save_folder)
        self.cur_number += 1

        # images
        show_images = []

        pixel_values = data_dict['pixel_values']
        save_folder_image = os.path.join(save_folder, 'image')
        if not os.path.exists(save_folder_image):
            os.mkdir(save_folder_image)
        for i_image, image_pixel_value in enumerate(pixel_values):
            # undo the Normalize transform channel by channel (x * std + mean)
            image_pixel_value[0] = image_pixel_value[0] * 0.2686
            image_pixel_value[1] = image_pixel_value[1] * 0.2613
            image_pixel_value[2] = image_pixel_value[2] * 0.2757
            image_pixel_value[0] = image_pixel_value[0] + 0.4814
            image_pixel_value[1] = image_pixel_value[1] + 0.4578
            image_pixel_value[2] = image_pixel_value[2] + 0.4082
            image_pixel_value = image_pixel_value * 255
            image_pixel_value = image_pixel_value.permute(1, 2, 0)
            image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
            show_images.append(image_pixel_value)
            cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)

        # text
        input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
        with open(os.path.join(save_folder, 'text.json'), 'w') as f:
            json.dump([input_text], f)

        # masks
        save_folder_mask = os.path.join(save_folder, 'mask')
        if not os.path.exists(save_folder_mask):
            os.mkdir(save_folder_mask)
        n_frames = len(pixel_values)
        masks = data_dict['masks']
        _, h, w = masks.shape
        masks = masks.reshape(-1, n_frames, h, w)
        for i_obj, obj_masks in enumerate(masks):
            save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
            if not os.path.exists(save_folder_mask_obj_folder):
                os.mkdir(save_folder_mask_obj_folder)
            for i_frame, f_mask in enumerate(obj_masks):
                f_mask = f_mask.numpy()
                f_mask = f_mask * 255
                # paint the mask into the first (blue, in BGR) channel and blend
                f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
                f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
                f_mask = f_mask.astype(np.uint8)
                cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
        return
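
    # Editor's sketch (not part of the original file): the constants above are
    # the inverse of the Normalize(mean, std) applied by `self.transformer`,
    # i.e. x_denorm = x * std + mean. A hypothetical vectorized equivalent:
    #
    #   def denormalize(img, mean=(0.4814, 0.4578, 0.4082), std=(0.2686, 0.2613, 0.2757)):
    #       # img: (3, H, W) float tensor -> de-normalized copy in [0, 1]
    #       mean = torch.tensor(mean).view(3, 1, 1)
    #       std = torch.tensor(std).view(3, 1, 1)
    #       return (img * std + mean).clamp(0, 1)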

def get_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print("Error: Cannot open video file.")
        return

    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)

    cap.release()
    return frames
            +
             | 
| 450 | 
            +
             | 
| 451 | 
            +
            def images_to_video(frames, video_name, fps=6):
         | 
| 452 | 
            +
                height, width, layers = frames[0].shape
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
         | 
| 455 | 
            +
                video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                for frame in frames:
         | 
| 458 | 
            +
                    video.write(frame)
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                # cv2.destroyAllWindows()
         | 
| 461 | 
            +
                video.release()
         | 
| 462 | 
            +
                return
         | 

def decode_masklet(masklet):
    masks = []
    for _rle in masklet:
        mask = maskUtils.decode(_rle)
        masks.append(mask)
    return masks

def draw_mask(image, mask):
    obj_mask = mask * 255
    obj_mask = np.stack([obj_mask * 1, obj_mask * 0, obj_mask * 0], axis=2)  # blue (first BGR) channel
    obj_mask = obj_mask * 0.5 + copy.deepcopy(image) * 0.5
    obj_mask = obj_mask.astype(np.uint8)
    return obj_mask

def add_mask2images(frames, masklets):
    show_videos = []
    for i_frames, (frame, masks) in enumerate(zip(frames, masklets)):
        if i_frames == 0:
            # one output video per object
            n_obj = masks.shape[-1]
            for i_obj in range(n_obj):
                show_videos.append([])

        n_obj = masks.shape[-1]
        for i_obj in range(n_obj):
            show_videos[i_obj].append(draw_mask(copy.deepcopy(frame), masks[:, :, i_obj]))
    return show_videos
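
# Editor's usage sketch (not part of the original file): chaining the helpers
# above to overlay per-object masklets on a clip and export one video per
# object. The input path and the per-frame (H, W, n_obj) mask layout are
# assumptions inferred from `add_mask2images`.
#
#   frames = get_video_frames('clip.mp4')            # list of BGR frames
#   masklets = [np.stack(decode_masklet(rles_per_frame), axis=-1)
#               for rles_per_frame in rle_annotations]  # hypothetical RLE source
#   show_videos = add_mask2images(frames, masklets)
#   for i_obj, obj_frames in enumerate(show_videos):
#       images_to_video(obj_frames, 'obj_{}.mp4'.format(i_obj), fps=6)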
projects/llava_sam2/datasets/ReVOS_Dataset.py
ADDED
@@ -0,0 +1,602 @@
import logging
import os
from typing import Literal

import torch
from datasets import Dataset as HFDataset
from datasets import DatasetDict
from mmengine import print_log
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

from xtuner.registry import BUILDER
from xtuner.dataset.huggingface import build_origin_dataset
import copy

from .encode_fn import video_lisa_encode_fn
import json
import random
import pycocotools.mask as maskUtils
import cv2
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode

SEG_QUESTIONS = [
    "Can you segment the {class_name} in this image?",
    "Please segment {class_name} in this image.",
    "What is {class_name} in this image? Please respond with segmentation mask.",
    "What is {class_name} in this image? Please output segmentation mask.",

    "Can you segment the {class_name} in this image",
    "Please segment {class_name} in this image",
    "What is {class_name} in this image? Please respond with segmentation mask",
    "What is {class_name} in this image? Please output segmentation mask",

    "Could you provide a segmentation mask for the {class_name} in this image?",
    "Please identify and segment the {class_name} in this image.",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
    "Can you highlight the {class_name} in this image with a segmentation mask?",

    "Could you provide a segmentation mask for the {class_name} in this image",
    "Please identify and segment the {class_name} in this image",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask",
    "Can you highlight the {class_name} in this image with a segmentation mask",
]

ANSWER_LIST = [
    "It is [SEG].",
    "Sure, [SEG].",
    "Sure, it is [SEG].",
    "Sure, the segmentation result is [SEG].",
    "[SEG].",
]
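
# Editor's sketch (not part of the original file): at sampling time an
# expression is slotted into a random template and paired with a random
# answer, e.g.
#
#   question = random.choice(SEG_QUESTIONS).format(class_name='red car')
#   answer = random.choice(ANSWER_LIST)
#   # -> "Please segment red car in this image." / "It is [SEG]."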

class VideoReVOSDataset(Dataset):
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
    FAST_IMG_START_TOKEN = '<fast_img>'
    FAST_IMG_END_TOKEN = '</fast_img>'

    def __init__(self,
                 image_folder,
                 expression_file,
                 mask_file,
                 extra_image_processor=None,
                 tokenizer=None,
                 select_number=5,
                 sampled_frames=10,
                 offline_processed_text_folder=None,
                 template_map_fn=None,
                 max_length=2048,
                 lazy=True,
                 repeats=1,
                 special_tokens=None,
                 frame_contiguous_sample=False,
                 use_fast=False,
                 arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
                 preprocessor=None,
                 # the following only take effect if use_fast is True
                 n_fast_images=50,
                 fast_pool_size=4,
                 fast_token_after_question=False,
    ):
        assert lazy is True
        self.tokenizer = BUILDER.build(tokenizer)
        self.select_number = select_number
        self.sampled_frames = sampled_frames
        assert offline_processed_text_folder or (expression_file and tokenizer)
        self.lazy = lazy

        self.max_length = max_length

        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if offline_processed_text_folder and expression_file:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`expression_file` are set; loading the dataset from '
                f'`offline_processed_text_folder` ({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        self.arch_type = arch_type
        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            vid2metaid, metas, mask_dict = self.json_file_preprocess(expression_file, mask_file)
            self.vid2metaid = vid2metaid
            self.videos = list(self.vid2metaid.keys())
            self.mask_dict = mask_dict
            self.json_datas = metas
            json_data = DatasetDict({'train': HFDataset.from_list(metas)})
            if self.lazy:
                self.text_data = build_origin_dataset(json_data, 'train')
            else:
                raise NotImplementedError

        self.image_folder = image_folder
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.down_ratio = 1
        self.repeats = repeats

        self._system = ''

        self.downsample_ratio = 0.5
        if self.arch_type == 'llava':
            self.downsample_ratio = 1
        self.image_size = 448
        if self.arch_type == 'llava':
            self.image_size = 336
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
        if self.arch_type == 'qwen':
            self.patch_token = 1

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)

        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.use_fast = use_fast
        self.n_fast_images = n_fast_images
        self.fast_pool_size = fast_pool_size

        self.frame_contiguous_sample = frame_contiguous_sample

        # for visualization debug
        self.save_folder = './work_dirs/video_debug/'
        self.cur_number = 0

        # exist_thr
        self.exist_thr = 8
        self.fast_token_after_question = fast_token_after_question
        if self.fast_token_after_question:
            assert self.use_fast

        print('Video ReVOS dataset contains {} videos.'.format(len(self.vid2metaid)))
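
    # Editor's construction sketch (not part of the original file): the dataset
    # is built through xtuner's BUILDER, so `tokenizer` (and optionally
    # `preprocessor`/`extra_image_processor`) are passed as config dicts; all
    # paths below are placeholders.
    #
    #   dataset = VideoReVOSDataset(
    #       image_folder='<revos-root>',
    #       expression_file='<revos-root>/<expressions>.json',
    #       mask_file='<revos-root>/<mask_dict>.json',
    #       tokenizer=dict(type=AutoTokenizer.from_pretrained,
    #                      pretrained_model_name_or_path='<llm-path>'),
    #       template_map_fn=dict(type=template_map_fn_factory, template='<template>'),
    #       special_tokens=['[SEG]'],
    #       sampled_frames=5,
    #   )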

    def __len__(self):
        return len(self.vid2metaid) * self.repeats

    @property
    def modality_length(self):
        # every video sample gets the same fixed pseudo-length
        return [10000] * len(self.vid2metaid)

    def real_len(self):
        return len(self.vid2metaid)

    def json_file_preprocess(self, expression_file, mask_file):
        # prepare expression annotation files
        with open(expression_file, 'r') as f:
            expression_datas = json.load(f)['videos']

        metas = []
        anno_count = 0  # serves as anno_id
        vid2metaid = {}
        for vid_name in expression_datas:
            vid_express_data = expression_datas[vid_name]

            vid_frames = sorted(vid_express_data['frames'])
            vid_len = len(vid_frames)

            exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
            for exp_id in exp_id_list:
                exp_dict = vid_express_data['expressions'][exp_id]
                meta = {}
                meta['video'] = vid_name
                meta['exp'] = exp_dict['exp']  # str
                meta['mask_anno_id'] = exp_dict['anno_id']

                if 'obj_id' in exp_dict.keys():
                    meta['obj_id'] = exp_dict['obj_id']
                else:
                    meta['obj_id'] = [0, ]  # Ref-Youtube-VOS only has one object per expression
                meta['anno_id'] = [str(anno_count), ]
                anno_count += 1
                meta['frames'] = vid_frames
                meta['exp_id'] = exp_id

                meta['length'] = vid_len
                metas.append(meta)
                if vid_name not in vid2metaid.keys():
                    vid2metaid[vid_name] = []
                vid2metaid[vid_name].append(len(metas) - 1)

        # process mask annotation files
        with open(mask_file, 'r') as f:
            mask_dict = json.load(f)

        return vid2metaid, metas, mask_dict
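
    # Shape of the two annotation files as consumed above (editor's sketch
    # reconstructed from the parsing code; all values are placeholders):
    #
    #   expression_file:
    #     {"videos": {
    #        "<video_name>": {
    #           "frames": ["00000", "00005", ...],
    #           "expressions": {
    #              "0": {"exp": "<referring expression>",
    #                    "anno_id": [...], "obj_id": [...]},
    #              ...}}}}
    #
    #   mask_file: maps each str(anno_id) to a per-frame list of COCO RLEs,
    #   with None for frames where the object is absent.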

    def create_img_to_refs_mapping(self, refs_train):
        img2refs = {}
        for ref in refs_train:
            img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
        return img2refs

    def decode_mask(self, video_masks, image_size):
        ret_masks = []
        for object_masks in video_masks:
            # no annotation for this object
            if len(object_masks) == 0:
                if len(ret_masks) != 0:
                    _object_masks = ret_masks[0] * 0
                else:
                    _object_masks = np.zeros(
                        (self.sampled_frames, image_size[0], image_size[1]), dtype=np.uint8)
            else:
                _object_masks = []
                for i_frame in range(len(object_masks[0])):
                    _mask = np.zeros(image_size, dtype=np.uint8)
                    for i_anno in range(len(object_masks)):
                        if object_masks[i_anno][i_frame] is None:
                            continue
                        m = maskUtils.decode(object_masks[i_anno][i_frame])
                        if m.ndim == 3:
                            m = m.sum(axis=2).astype(np.uint8)
                        else:
                            m = m.astype(np.uint8)
                        _mask = _mask | m
                    _object_masks.append(_mask)
                _object_masks = np.stack(_object_masks, axis=0)
            # if self.pad_image_to_square:
            #     _object_masks = expand2square_mask(_object_masks)
            ret_masks.append(_object_masks)
        _shape = ret_masks[0].shape
        for item in ret_masks:
            if item.shape != _shape:
                print([_ret_mask.shape for _ret_mask in ret_masks])
                return None
        ret_masks = np.stack(ret_masks, axis=0)  # (n_obj, n_frames, h, w)

        ret_masks = torch.from_numpy(ret_masks)
        # ret_masks = F.interpolate(ret_masks, size=(self.image_size // self.down_ratio,
        #                           self.image_size // self.down_ratio), mode='nearest')
        ret_masks = ret_masks.flatten(0, 1)  # (n_obj * n_frames, h, w)
        return ret_masks
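
    # Editor's round-trip sketch (not part of the original file) showing the
    # per-frame RLE format `decode_mask` consumes via pycocotools:
    #
    #   binary = np.zeros((480, 640), dtype=np.uint8)
    #   binary[100:200, 150:300] = 1
    #   rle = maskUtils.encode(np.asfortranarray(binary))  # {'size': [480, 640], 'counts': ...}
    #   assert (maskUtils.decode(rle) == binary).all()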

    def dataset_map_fn(self, data_dict, select_k=5):
        images = []

        len_frames = len(data_dict[0]['frames'])
        for object_info in data_dict:
            assert len_frames == len(object_info['frames'])

        # prepare images: randomly select k frames
        if len_frames > select_k + 1:
            if self.frame_contiguous_sample and random.random() < 0.5:
                # contiguous sampling
                selected_start_frame = np.random.choice(len_frames - select_k, 1, replace=False)
                selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(select_k)]
            else:
                selected_frame_indexes = np.random.choice(len_frames, select_k, replace=False)
        else:
            selected_frame_indexes = np.random.choice(len_frames, select_k, replace=True)
        selected_frame_indexes.sort()

        if self.use_fast:
            # sample the fast branch uniformly over the whole clip
            fast_interval = len_frames / (self.n_fast_images + 1e-4)
            sampled_fast_frame_idxs = [min(int(i * fast_interval), len_frames - 1) for i in range(self.n_fast_images)]
            fast_video_frames = []
            for selected_frame_index in sampled_fast_frame_idxs:
                frame_id = data_dict[0]['frames'][selected_frame_index]
                fast_video_frames.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
        else:
            fast_video_frames = None
            sampled_fast_frame_idxs = None

        for selected_frame_index in selected_frame_indexes:
            frame_id = data_dict[0]['frames'][selected_frame_index]
            images.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))

        # prepare text
        expressions = [object_info['exp'] for object_info in data_dict]
        if self.use_fast:
            text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token,
                                          n_fast_images=len(fast_video_frames))
        else:
            text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token)

        # prepare masks
        video_masks = []
        for object_info in data_dict:
            anno_ids = object_info['mask_anno_id']
            obj_masks = []
            for anno_id in anno_ids:
                anno_id = str(anno_id)
                frames_masks = self.mask_dict[anno_id]
                frames_masks_ = []
                for frame_idx in selected_frame_indexes:
                    frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
                obj_masks.append(frames_masks_)
            video_masks.append(obj_masks)

        if self.use_fast:
            fast_video_masks = []
            assert sampled_fast_frame_idxs is not None
            for object_info in data_dict:
                anno_ids = object_info['mask_anno_id']
                obj_masks = []
                for anno_id in anno_ids:
                    anno_id = str(anno_id)
                    frames_masks = self.mask_dict[anno_id]
                    frames_masks_ = []
                    for frame_idx in sampled_fast_frame_idxs:
                        frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
                    obj_masks.append(frames_masks_)
                fast_video_masks.append(obj_masks)
        else:
            fast_video_masks = None

        ret = {'images': images, 'video_masks': video_masks, 'conversation': text_dict['conversation'],
               'fast_images': fast_video_frames, 'fast_video_masks': fast_video_masks}
        return ret
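
    # Editor's numeric sketch (not part of the original file): for
    # len_frames=30 and n_fast_images=10 above,
    #   fast_interval = 30 / (10 + 1e-4) ~ 3.0
    #   sampled_fast_frame_idxs == [0, 2, 5, 8, 11, 14, 17, 20, 23, 26]
    # i.e. the fast branch covers the whole clip uniformly, while the slow
    # branch only sees the `select_k` randomly sampled frames.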

    def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_images=50):

        if self.use_fast and not self.fast_token_after_question:
            fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
                          f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
                          f'{self.FAST_IMG_END_TOKEN}' + '\n'
        else:
            fast_frame_token_str = ''

        frame_token_str = f'{self.IMG_START_TOKEN}' \
                          f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                          f'{self.IMG_END_TOKEN}'
        if self.fast_token_after_question:
            assert self.use_fast
            after_question_str = f'{self.FAST_IMG_START_TOKEN}' \
                          f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
                          f'{self.FAST_IMG_END_TOKEN}'
        else:
            after_question_str = ''

        questions = []
        answers = []
        for exp in expressions:
            if '?' in exp:
                # the expression is already a question
                questions.append(exp)
            else:
                exp = exp.replace('.', '').strip()
                question_template = random.choice(SEG_QUESTIONS)
                questions.append(question_template.format(class_name=exp.lower()))

            answers.append(random.choice(ANSWER_LIST))
        qa_list = []
        for i, (question, answer) in enumerate(zip(questions, answers)):
            if i == 0:
                frame_tokens = frame_token_str + '\n'
                frame_tokens = frame_tokens * n_frames
                frame_tokens = frame_tokens.strip()
                frame_tokens = fast_frame_token_str + frame_tokens
                qa_list.append(
                    {'from': 'human', 'value': frame_tokens + question + after_question_str}
                )
            else:
                qa_list.append(
                    {'from': 'human', 'value': question + after_question_str}
                )
            qa_list.append(
                {'from': 'gpt', 'value': answer}
            )

        input = ''
        conversation = []
        for msg in qa_list:
            if msg['from'] == 'human':
                input += msg['value']
            elif msg['from'] == 'gpt':
                conversation.append({'input': input, 'output': msg['value']})
                input = ''
            else:
                raise NotImplementedError

        # add system information
        conversation[0].update({'system': self._system})
        return {'conversation': conversation}
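
    # Editor's sketch (not part of the original file): for n_frames=2, one
    # expression and use_fast=False, `prepare_text` returns roughly
    #
    #   {'conversation': [{
    #       'system': '',
    #       'input': '<img>{IMG_CONTEXT * N}</img>\n'
    #                '<img>{IMG_CONTEXT * N}</img>'
    #                'Please segment the red car in this image.',
    #       'output': 'It is [SEG].'}]}
    #
    # with N == num_image_tokens; each additional expression appends another
    # {'input': question, 'output': answer} turn.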

    def __getitem__(self, index):
        index = index % self.real_len()
        selected_video_objects = self.vid2metaid[self.videos[index]]
        video_objects_infos = [copy.deepcopy(self.text_data[idx]) for idx in selected_video_objects]

        if len(video_objects_infos) > self.select_number:
            # enough objects available: sample without replacement
            selected_indexes = np.random.choice(len(video_objects_infos), self.select_number, replace=False)
        else:
            selected_indexes = np.random.choice(len(video_objects_infos), self.select_number, replace=True)
        video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]

        data_dict = self.dataset_map_fn(video_objects_infos, select_k=self.sampled_frames)

        assert 'images' in data_dict.keys()
        pixel_values = []
        extra_pixel_values = []
        num_video_tokens = None
        num_frame_tokens = None
        if data_dict.get('images', None) is not None:
            frames_files = data_dict['images']
            frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
            for frame_path in frames_files:
                frame_image = Image.open(frame_path).convert('RGB')
                ori_width, ori_height = frame_image.size
                if self.extra_image_processor is not None:
                    g_image = np.array(frame_image)  # for grounding
                    g_image = self.extra_image_processor.apply_image(g_image)
                    g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                    extra_pixel_values.append(g_pixel_values)

                if self.preprocessor is None:
                    frame_image = self.transformer(frame_image)
                pixel_values.append(frame_image)

            if self.preprocessor is not None:
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
                    num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                    num_frames = _data_dict['image_grid_thw'].shape[0]
                    num_video_tokens = num_frame_tokens * num_frames
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                else:
                    raise NotImplementedError
                data_dict.update(_data_dict)
            else:
                pixel_values = torch.stack(pixel_values, dim=0)  # (n_f, 3, h, w)
                data_dict['pixel_values'] = pixel_values
            if self.extra_image_processor is not None:
                data_dict['g_pixel_values'] = extra_pixel_values

            # process and get masks
            masks = self.decode_mask(data_dict['video_masks'], image_size=(ori_height, ori_width))
            if masks is None:
                # inconsistent mask shapes: resample another video
                return self.__getitem__(random.randint(0, self.real_len() - 1))
            data_dict['masks'] = masks
        else:
            data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
            data_dict['masks'] = None

        if num_video_tokens is not None:
            assert self.patch_token == 1
            input_str = data_dict['conversation'][0]['input']
            input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
            assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
            data_dict['conversation'][0]['input'] = input_str

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length)
        data_dict.update(result)

        # for fast branch
| 521 | 
            +
                    if self.use_fast:
         | 
| 522 | 
            +
                        fast_pixel_values = []
         | 
| 523 | 
            +
                        frames_files = data_dict['fast_images']
         | 
| 524 | 
            +
                        frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
         | 
| 525 | 
            +
                        for frame_path in frames_files:
         | 
| 526 | 
            +
                            frame_image = Image.open(frame_path).convert('RGB')
         | 
| 527 | 
            +
                            ori_width, ori_height = frame_image.size
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                            frame_image = self.transformer(frame_image)
         | 
| 530 | 
            +
                            fast_pixel_values.append(frame_image)
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                        fast_pixel_values = torch.stack(fast_pixel_values, dim=0)  # (n_f, 3, h, w)
         | 
| 533 | 
            +
                        data_dict['fast_pixel_values'] = fast_pixel_values
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                        # process and get masks
         | 
| 536 | 
            +
                        masks = self.decode_mask(data_dict['fast_video_masks'], image_size=(ori_height, ori_width))
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                        if masks is None:
         | 
| 539 | 
            +
                            return self.__getitem__(random.randint(0, self.real_len()))
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                        data_dict['fast_exists'] = masks.to(dtype=torch.int).sum(dim=(-2, -1)).ge(self.exist_thr).unsqueeze(-1)
         | 
| 542 | 
            +
             | 
| 543 | 
            +
             | 
| 544 | 
            +
                        del data_dict['fast_video_masks']
         | 
| 545 | 
            +
                    data_dict['type'] = 'video'
         | 
| 546 | 
            +
                    return data_dict
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                def visualization_debug(self, data_dict):
         | 
| 549 | 
            +
                    save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
         | 
| 550 | 
            +
                    if not os.path.exists(save_folder):
         | 
| 551 | 
            +
                        os.mkdir(save_folder)
         | 
| 552 | 
            +
                    self.cur_number += 1
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                    # images
         | 
| 555 | 
            +
             | 
| 556 | 
            +
                    show_images = []
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                    pixel_values = data_dict['pixel_values']
         | 
| 559 | 
            +
                    save_folder_image = os.path.join(save_folder, 'image')
         | 
| 560 | 
            +
                    if not os.path.exists(save_folder_image):
         | 
| 561 | 
            +
                        os.mkdir(save_folder_image)
         | 
| 562 | 
            +
                    for i_image, image_pixel_value in enumerate(pixel_values):
         | 
| 563 | 
            +
                        # print(image_pixel_value.shape)
         | 
| 564 | 
            +
                        image_pixel_value[0] = image_pixel_value[0] * 0.2686
         | 
| 565 | 
            +
                        image_pixel_value[1] = image_pixel_value[1] * 0.2613
         | 
| 566 | 
            +
                        image_pixel_value[2] = image_pixel_value[2] * 0.2757
         | 
| 567 | 
            +
                        image_pixel_value[0] = image_pixel_value[0] + 0.4814
         | 
| 568 | 
            +
                        image_pixel_value[1] = image_pixel_value[1] + 0.4578
         | 
| 569 | 
            +
                        image_pixel_value[2] = image_pixel_value[2] + 0.4082
         | 
| 570 | 
            +
                        image_pixel_value = image_pixel_value * 255
         | 
| 571 | 
            +
                        image_pixel_value = image_pixel_value.permute(1, 2, 0)
         | 
| 572 | 
            +
                        image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
         | 
| 573 | 
            +
                        # print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
         | 
| 574 | 
            +
                        # print(image_pixel_value.shape)
         | 
| 575 | 
            +
                        show_images.append(image_pixel_value)
         | 
| 576 | 
            +
                        cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
         | 
| 577 | 
            +
             | 
| 578 | 
            +
                    # text
         | 
| 579 | 
            +
                    input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
         | 
| 580 | 
            +
                    with open(os.path.join(save_folder, 'text.json'), 'w') as f:
         | 
| 581 | 
            +
                        json.dump([input_text], f)
         | 
| 582 | 
            +
             | 
| 583 | 
            +
                    # masks
         | 
| 584 | 
            +
                    save_folder_mask = os.path.join(save_folder, 'mask')
         | 
| 585 | 
            +
                    if not os.path.exists(save_folder_mask):
         | 
| 586 | 
            +
                        os.mkdir(save_folder_mask)
         | 
| 587 | 
            +
                    n_frames = len(pixel_values)
         | 
| 588 | 
            +
                    masks = data_dict['masks']
         | 
| 589 | 
            +
                    _, h, w = masks.shape
         | 
| 590 | 
            +
                    masks = masks.reshape(-1, n_frames, h, w)
         | 
| 591 | 
            +
                    for i_obj, obj_masks in enumerate(masks):
         | 
| 592 | 
            +
                        save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
         | 
| 593 | 
            +
                        if not os.path.exists(save_folder_mask_obj_folder):
         | 
| 594 | 
            +
                            os.mkdir(save_folder_mask_obj_folder)
         | 
| 595 | 
            +
                        for i_frame, f_mask in enumerate(obj_masks):
         | 
| 596 | 
            +
                            f_mask = f_mask.numpy()
         | 
| 597 | 
            +
                            f_mask = f_mask * 255
         | 
| 598 | 
            +
                            f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
         | 
| 599 | 
            +
                            f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
         | 
| 600 | 
            +
                            f_mask = f_mask.astype(np.uint8)
         | 
| 601 | 
            +
                            cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
         | 
| 602 | 
            +
                    return
         | 
    	
        projects/llava_sam2/datasets/RefCOCO_Dataset.py
    ADDED
    
@@ -0,0 +1,338 @@
import copy
import random
import glob
import json
import logging
import os
from typing import Literal

import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER
from xtuner.utils import IGNORE_INDEX
from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square, SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from third_parts.mmdet.datasets.refcoco import RefCocoDataset

from .utils import dynamic_preprocess


class ReferSegmDataset(RefCocoDataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 data_root,
                 ann_file=None,
                 split_file=None,
                 special_tokens=None,
                 prompt_template=None,
                 extra_image_processor=None,
                 data_prefix=dict(img_path='train2014/'),
                 tokenizer=None,
                 max_length=2048,
                 num_classes_per_sample=3,
                 single_image_mode=False,
                 arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
                 preprocessor=None,
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            pipeline=None,
            ann_file=ann_file,
            split_file=split_file,
            **kwargs,
        )
        self.begin_str = f'{DEFAULT_IMAGE_TOKEN}\n'
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.arch_type = arch_type
        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''

        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.image_folder = data_root
        self.template = prompt_template
        self.max_length = max_length
        if self.arch_type == 'intern_vl':
            # Default InternVL system prompt (disabled): 'You are InternVL, a multimodal
            # large model jointly developed by Shanghai AI Laboratory and SenseTime,
            # a helpful and harmless AI assistant.'
            self._system = ''
            self.template['INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'
        elif self.arch_type == 'qwen':
            self._system = ''
        elif self.arch_type == 'llava':
            self._system = ''

        self.num_classes_per_sample = num_classes_per_sample
        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        if self.arch_type == 'llava':
            self.downsample_ratio = 1
        self.image_size = 448
        if self.arch_type == 'llava':
            self.image_size = 336
        self.use_thumbnail = True
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)
        self.single_image_mode = single_image_mode
        self._max_refetch = 1000

        print("Image RES dataset, includes {} items.".format(len(self)))

    @property
    def modality_length(self):
        # a constant pseudo-length per sample is enough for the length-grouped sampler
        return [100] * len(self)

    def _parse_annotations(self, ann_info):
        image_path = ann_info['img_path']
        image = Image.open(image_path).convert('RGB')
        width, height = image.size

        masks, phrases = [], []
        instances, text = ann_info['instances'], ann_info['text']
        index = np.random.choice(range(len(instances)), self.num_classes_per_sample, replace=True)
        for idx in index:
            inst = instances[idx]
            phrase = text[idx].lower()
            if '.' == phrase[-1]:
                phrase = phrase[:-1]
            phrases.append(phrase)
            binary_mask = np.zeros((height, width), dtype=np.uint8)
            for seg in inst["mask"]:
                rles = mask_utils.frPyObjects([seg], height, width)
                m = mask_utils.decode(rles)
                m = m.astype(np.uint8)
                binary_mask += m.squeeze()
            masks.append(binary_mask)

        conversation = []
        for i, phrase in enumerate(phrases):
            question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
            if i == 0:
                question = self.begin_str + question
            conversation.append({'from': 'human', 'value': question})
            conversation.append({'from': 'gpt', 'value': random.choice(ANSWER_LIST)})
        masks = torch.stack([torch.from_numpy(mask) for mask in masks], dim=0)

        ann_info.update({
            'masks': masks,
            'conversations': conversation,
            'image': image_path
        })
        return ann_info

    def prepare_data(self, index):
        data_dict = super().prepare_data(index)
        data_dict = self._parse_annotations(data_dict)
        if data_dict is None:
            return None

        out_data_dict = {}
        if 'masks' in data_dict:
            out_data_dict['masks'] = data_dict['masks']

        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                print(f'Error: {e}', flush=True)
                print_log(f'Error: {e}', logger='current')
                return None
            if hasattr(self, 'extra_image_processor'):
                g_image = np.array(image)  # for grounding
                g_image = self.extra_image_processor.apply_image(g_image)
                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                out_data_dict['g_pixel_values'] = g_pixel_values

            if self.single_image_mode:
                images = [image]
            else:
                images = dynamic_preprocess(image, self.min_dynamic_patch,
                                            self.max_dynamic_patch,
                                            self.image_size, self.use_thumbnail)
            if self.preprocessor is not None:
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(images, do_resize=True)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
                    num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
                else:
                    raise NotImplementedError
                out_data_dict.update(_data_dict)
            else:
                pixel_values = [self.transformer(image) for image in images]
                pixel_values = torch.stack(pixel_values)
                out_data_dict['pixel_values'] = pixel_values

                num_image_tokens = pixel_values.shape[0] * self.patch_token
            image_token_str = f'{self.IMG_START_TOKEN}' \
                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                              f'{self.IMG_END_TOKEN}'
            token_dict = self.get_inputid_labels(data_dict['conversations'], image_token_str)
            out_data_dict.update(token_dict)
        else:
            token_dict = self.get_inputid_labels(data_dict['conversations'], None)
            out_data_dict.update(token_dict)
            out_data_dict['pixel_values'] = torch.zeros(1, 3, self.image_size, self.image_size)
        return out_data_dict

    def get_inputid_labels(self, conversations, image_token_str) -> dict:
        input = ''
        out_conversation = []
        while conversations and conversations[0]['from'] == 'gpt':
            # Skip the first one if it is from gpt
            conversations = conversations[1:]
        for msg in conversations:
            if msg['from'] == 'human':
                if image_token_str is None and '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', '')
                if '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
                input += msg['value'].strip()
            elif msg['from'] == 'gpt':
                out_conversation.append({
                    'input': input,
                    'output': msg['value'].strip()
                })
                input = ''
            else:
                raise NotImplementedError

        input_ids, labels = [], []
        for i, single_turn_conversation in enumerate(out_conversation):
            input = single_turn_conversation.get('input', '')
            if input is None:
                input = ''
            input_text = self.template.INSTRUCTION.format(
                input=input, round=i + 1)

            if i == 0:
                if self._system != '' and self._system is not None:
                    system = self.template.SYSTEM.format(system=self._system)
                    input_text = system + input_text
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=True)
            else:
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=False)
            input_ids += input_encode
            # prompt tokens are masked out of the loss
            labels += [IGNORE_INDEX] * len(input_encode)

            output_text = single_turn_conversation.get('output', '')
            if self.template.get('SUFFIX', None):
                output_text += self.template.SUFFIX
            output_encode = self.tokenizer.encode(
                output_text, add_special_tokens=False)
            input_ids += output_encode
            labels += copy.deepcopy(output_encode)

        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
        return {'input_ids': input_ids, 'labels': labels}

    def __getitem__(self, index):
        for _ in range(self._max_refetch + 1):
            data = self.prepare_data(index)
            # Broken images may cause the returned data to be None
            if data is None:
                index = self._rand_another()
                continue
            return data


if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide

    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE

    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn

    dataset = ReferSegmDataset(
        tokenizer=tokenizer,
        special_tokens=['[SEG]'],
        extra_image_processor=extra_image_processor,
        prompt_template=prompt_template,
        data_root='data/coco/',
        data_prefix=dict(img_path='train2014/'),
        ann_file='refcoco+/instances.json',
        split_file='refcoco+/refs(unc).p',
    )
    for i in range(1000):
        dataset[i]
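
Note: get_inputid_labels above follows the standard supervised fine-tuning convention: prompt tokens are masked with IGNORE_INDEX (-100) so the loss is computed only on assistant tokens. A minimal, tokenizer-free sketch of that masking pattern with made-up token ids:

    IGNORE_INDEX = -100  # same constant xtuner uses

    input_ids, labels = [], []
    for turn in [{'input': [1, 2, 3], 'output': [4, 5]},   # made-up token ids
                 {'input': [6], 'output': [7, 8, 9]}]:
        input_ids += turn['input']
        labels += [IGNORE_INDEX] * len(turn['input'])  # prompt: no loss
        input_ids += turn['output']
        labels += turn['output']                       # answer: supervised

    print(input_ids)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
    print(labels)     # [-100, -100, -100, 4, 5, -100, 7, 8, 9]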
    	
        projects/llava_sam2/datasets/RefYoutubeVOS_Dataset.py
    ADDED
    
@@ -0,0 +1,47 @@
from .ReVOS_Dataset import VideoReVOSDataset
import json
import pickle

class VideoRefYoutubeVOSDataset(VideoReVOSDataset):

    def json_file_preprocess(self, expression_file, mask_file):
        # prepare expression annotation files
        with open(expression_file, 'r') as f:
            expression_datas = json.load(f)['videos']

        metas = []
        anno_count = 0  # serves as anno_id
        vid2metaid = {}
        for vid_name in expression_datas:
            vid_express_data = expression_datas[vid_name]

            vid_frames = sorted(vid_express_data['frames'])
            vid_len = len(vid_frames)

            exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
            for exp_id in exp_id_list:
                exp_dict = vid_express_data['expressions'][exp_id]
                meta = {}
                meta['video'] = vid_name
                meta['exp'] = exp_dict['exp']  # str
                meta['mask_anno_id'] = [str(anno_count), ]

                if 'obj_id' in exp_dict.keys():
                    meta['obj_id'] = exp_dict['obj_id']
                else:
                    meta['obj_id'] = [0, ]  # Ref-Youtube-VOS only has one object per expression
                meta['anno_id'] = [str(anno_count), ]
                anno_count += 1
                meta['frames'] = vid_frames
                meta['exp_id'] = exp_id

                meta['length'] = vid_len
                metas.append(meta)
                if vid_name not in vid2metaid:
                    vid2metaid[vid_name] = []
                vid2metaid[vid_name].append(len(metas) - 1)

        # process mask annotation files
        with open(mask_file, 'rb') as f:
            mask_dict = pickle.load(f)
        return vid2metaid, metas, mask_dict
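
Note: json_file_preprocess above assumes the standard Ref-YouTube-VOS meta_expressions.json layout. A minimal illustrative instance (the video id, frame names, and expressions are made up; only the field names come from the parsing code):

    expression_file_contents = {
        'videos': {
            'abc123': {                          # hypothetical video id
                'frames': ['00000', '00005', '00010'],
                'expressions': {
                    '0': {'exp': 'a person riding a bike', 'obj_id': '1'},
                    '1': {'exp': 'the bike on the left'},  # obj_id may be absent
                },
            },
        },
    }
    # Each (video, expression) pair becomes one meta entry carrying the video
    # name, the expression string, the sorted frame list, and a running anno_id
    # that doubles as the key into the pickled mask_dict.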
    	
        projects/llava_sam2/datasets/__init__.py
    ADDED
    
@@ -0,0 +1,15 @@
from .collect_fns import video_lisa_collate_fn
from .MeVIS_Dataset import VideoMeVISDataset
from .ReVOS_Dataset import VideoReVOSDataset
from .RefYoutubeVOS_Dataset import VideoRefYoutubeVOSDataset
from .encode_fn import video_lisa_encode_fn
from .RefCOCO_Dataset import ReferSegmDataset
from .ReSAM2_Dataset import VideoSAM2Dataset
from .vqa_dataset import LLaVADataset, InfinityMMDataset

from .GCG_Dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset
from .Grand_Dataset import GranDDataset

from .Osprey_Dataset import OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset

from .ChatUniVi_Dataset import VideoChatUniViDataset
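
Note: these exports are normally consumed from an xtuner config rather than imported directly; nested components (tokenizer, preprocessor, ...) are passed as dicts and built lazily via BUILDER, as in the __main__ demo of RefCOCO_Dataset.py above. A minimal, self-contained sketch of that registry pattern with a toy class:

    from xtuner.registry import BUILDER

    class ToyDataset:
        def __init__(self, data_root, ann_file):
            self.data_root, self.ann_file = data_root, ann_file

    # 'type' may be a callable; the mmengine registry instantiates it
    # with the remaining keys as keyword arguments
    cfg = dict(type=ToyDataset, data_root='data/coco/', ann_file='refcoco+/instances.json')
    dataset = BUILDER.build(cfg)
    print(dataset.data_root)  # data/coco/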
    	
        projects/llava_sam2/datasets/collect_fns.py
    ADDED
    
@@ -0,0 +1,206 @@
from typing import Dict, Sequence

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def video_lisa_collate_fn(instances: Sequence[Dict],
                          pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                          return_hf_format: bool = False,
                          use_varlen_attn: bool = False):
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_pe = any(inst.get('image_grid_thw', None) is not None for inst in instances)
    has_fast_image = any(inst.get('fast_pixel_values', None) is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)
    has_fast_exists = any(inst.get('fast_exists') is not None for inst in instances)

    has_vp = any(inst.get('vp_overall_mask') is not None for inst in instances)
    has_prompt_mask = any(inst.get('prompt_masks') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, (
            'Currently, it is not configured to accommodate the use of '
            'varlen Attention in multimodal training')

    if has_image:
        pixel_values = []
        frames_per_batch = []
        image_grid_thw = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []
    if has_fast_image:
        fast_pixel_values = []
    if has_fast_exists:
        fast_exists = []
    if has_vp:
        vp_overall_mask = []
    else:
        vp_overall_mask = None

    if has_prompt_mask:
        prompt_masks = []
    else:
        prompt_masks = None

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
            if has_pe:
                image_grid_thw.append(example['image_grid_thw'])
            if has_vp:
                if 'vp_overall_mask' in example.keys() and example['vp_overall_mask'] is not None:
                    vp_overall_mask.append(example['vp_overall_mask'])
                else:
                    vp_overall_mask.append(torch.Tensor([False] * len(pixel_values[-1])))
        if has_fast_image:
            if 'fast_pixel_values' in example.keys() and example['fast_pixel_values'] is not None:
                fast_pixel_values.append(example['fast_pixel_values'])
        if has_fast_exists:
            if 'fast_exists' in example.keys() and example['fast_exists'] is not None:
                fast_exists.append(example['fast_exists'])
        if has_grounding_image and 'g_pixel_values' in example.keys():
            if isinstance(example['g_pixel_values'], list):
                grounding_pixel_values += example['g_pixel_values']
                frames_per_batch.append(len(example['g_pixel_values']))
            else:
                grounding_pixel_values.append(example['g_pixel_values'])
                frames_per_batch.append(1)

        if has_mask:
            if 'masks' in example.keys() and example['masks'] is not None:
                if isinstance(example['masks'], list):
                    if isinstance(example['masks'][0], np.ndarray):
                        _masks = np.stack(example['masks'], axis=0)
                        _masks = torch.from_numpy(_masks)
                        object_masks.append(_masks)
                    else:
                        object_masks.append(torch.stack(example['masks'], dim=0))
                else:
                    object_masks.append(example['masks'])
        if has_bboxes:
            if 'bboxes' in example.keys() and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example.keys() and example['points'] is not None:
                prompt_points.append(example['points'])

        if has_prompt_mask:
            if 'prompt_masks' in example.keys():
                prompt_masks.append(example['prompt_masks'])
                if has_grounding_image:
         | 
| 44 | 
            +
                    grounding_pixel_values = []
         | 
| 45 | 
            +
                if has_mask:
         | 
| 46 | 
            +
                    object_masks = []
         | 
| 47 | 
            +
                if has_bboxes:
         | 
| 48 | 
            +
                    object_bboxes = []
         | 
| 49 | 
            +
                if has_points:
         | 
| 50 | 
            +
                    prompt_points = []
         | 
| 51 | 
            +
                if has_fast_image:
         | 
| 52 | 
            +
                    fast_pixel_values = []
         | 
| 53 | 
            +
                if has_fast_exists:
         | 
| 54 | 
            +
                    fast_exists = []
         | 
| 55 | 
            +
                if has_vp:
         | 
| 56 | 
            +
                    vp_overall_mask = []
         | 
| 57 | 
            +
                else:
         | 
| 58 | 
            +
                    vp_overall_mask = None
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                if has_prompt_mask:
         | 
| 61 | 
            +
                    prompt_masks = []
         | 
| 62 | 
            +
                else:
         | 
| 63 | 
            +
                    prompt_masks = None
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                for example in instances:
         | 
| 66 | 
            +
                    input_ids.append(torch.LongTensor(example['input_ids']))
         | 
| 67 | 
            +
                    labels.append(torch.LongTensor(example['labels']))
         | 
| 68 | 
            +
                    if use_varlen_attn:
         | 
| 69 | 
            +
                        cumulative_len.append(torch.IntTensor(example['cumulative_len']))
         | 
| 70 | 
            +
                        position_ids.append(torch.LongTensor(example['position_ids']))
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    if has_image:
         | 
| 73 | 
            +
                        pixel_values.append(example['pixel_values'])
         | 
| 74 | 
            +
                        if has_pe:
         | 
| 75 | 
            +
                            image_grid_thw.append(example['image_grid_thw'])
         | 
| 76 | 
            +
                        if has_vp:
         | 
| 77 | 
            +
                            if 'vp_overall_mask' in example.keys() and example['vp_overall_mask'] is not None:
         | 
| 78 | 
            +
                                vp_overall_mask.append(example['vp_overall_mask'])
         | 
| 79 | 
            +
                            else:
         | 
| 80 | 
            +
                                vp_overall_mask.append(torch.Tensor([False] * len(pixel_values[-1])))
         | 
| 81 | 
            +
                    if has_fast_image:
         | 
| 82 | 
            +
                        if 'fast_pixel_values' in example.keys() and example['fast_pixel_values'] is not None:
         | 
| 83 | 
            +
                            fast_pixel_values.append(example['fast_pixel_values'])
         | 
| 84 | 
            +
                    if has_fast_exists:
         | 
| 85 | 
            +
                        if 'fast_exists' in example.keys() and example['fast_exists'] is not None:
         | 
| 86 | 
            +
                            fast_exists.append(example['fast_exists'])
         | 
| 87 | 
            +
                    if has_grounding_image and 'g_pixel_values' in example.keys():
         | 
| 88 | 
            +
                        if isinstance(example['g_pixel_values'], list):
         | 
| 89 | 
            +
                            grounding_pixel_values += example['g_pixel_values']
         | 
| 90 | 
            +
                            frames_per_batch.append(len(example['g_pixel_values']))
         | 
| 91 | 
            +
                        else:
         | 
| 92 | 
            +
                            grounding_pixel_values.append(example['g_pixel_values'])
         | 
| 93 | 
            +
                            frames_per_batch.append(1)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    if has_mask:
         | 
| 96 | 
            +
                        if 'masks' in example.keys() and example['masks'] is not None:
         | 
| 97 | 
            +
                            if isinstance(example['masks'], list):
         | 
| 98 | 
            +
                                if isinstance(example['masks'][0], np.ndarray):
         | 
| 99 | 
            +
                                    _masks = np.stack(example['masks'], axis=0)
         | 
| 100 | 
            +
                                    _masks = torch.from_numpy(_masks)
         | 
| 101 | 
            +
                                    object_masks.append(_masks)
         | 
| 102 | 
            +
                                else:
         | 
| 103 | 
            +
                                    object_masks.append(torch.stack(example['masks'], dim=0))
         | 
| 104 | 
            +
                            else:
         | 
| 105 | 
            +
                                object_masks.append(example['masks'])
         | 
| 106 | 
            +
                    if has_bboxes:
         | 
| 107 | 
            +
                        if 'bboxes' in example.keys() and example['bboxes'] is not None:
         | 
| 108 | 
            +
                            object_bboxes.append(example['bboxes'])
         | 
| 109 | 
            +
                    if has_points:
         | 
| 110 | 
            +
                        if 'points' in example.keys() and example['points'] is not None:
         | 
| 111 | 
            +
                            prompt_points.append(example['points'])
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    if has_prompt_mask:
         | 
| 114 | 
            +
                        if 'prompt_masks' in example.keys():
         | 
| 115 | 
            +
                            prompt_masks.append(example['prompt_masks'])
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                ori_length = [len(ids) for ids in input_ids]
         | 
| 118 | 
            +
                if len(instances) > 1:
         | 
| 119 | 
            +
                    input_ids = pad_sequence(
         | 
| 120 | 
            +
                        input_ids, batch_first=True, padding_value=pad_index)
         | 
| 121 | 
            +
                    labels = pad_sequence(
         | 
| 122 | 
            +
                        labels, batch_first=True, padding_value=IGNORE_INDEX)
         | 
| 123 | 
            +
                else:
         | 
| 124 | 
            +
                    input_ids = torch.stack(input_ids)
         | 
| 125 | 
            +
                    labels = torch.stack(labels)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                if use_varlen_attn:
         | 
| 128 | 
            +
                    assert input_ids.size(1) % seq_parallel_world_size == 0
         | 
| 129 | 
            +
                    attention_mask = None
         | 
| 130 | 
            +
                    position_ids = torch.stack(position_ids, dim=0)
         | 
| 131 | 
            +
                else:
         | 
| 132 | 
            +
                    # Some tokenizers have the same eos token and pad token, so input_ids
         | 
| 133 | 
            +
                    # cannot be masked directly based on the pad token id.
         | 
| 134 | 
            +
                    attention_mask = torch.zeros_like(input_ids).bool()
         | 
| 135 | 
            +
                    for i, length in enumerate(ori_length):
         | 
| 136 | 
            +
                        attention_mask[i, :length] = True
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    bs, seq_len = input_ids.shape
         | 
| 139 | 
            +
                    position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                if seq_parallel_world_size > 1:
         | 
| 142 | 
            +
                    input_ids = pad_for_sequence_parallel(input_ids, pad_index)
         | 
| 143 | 
            +
                    labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
         | 
| 144 | 
            +
                    position_ids = pad_for_sequence_parallel(position_ids, 0)
         | 
| 145 | 
            +
                    if attention_mask is not None:
         | 
| 146 | 
            +
                        attention_mask = pad_for_sequence_parallel(attention_mask, 0)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                if use_varlen_attn:
         | 
| 149 | 
            +
                    max_seqlen = (
         | 
| 150 | 
            +
                        cumulative_len[0][1:] -  # noqa: W504
         | 
| 151 | 
            +
                        cumulative_len[0][:-1]).max().item()
         | 
| 152 | 
            +
                    data_dict = {
         | 
| 153 | 
            +
                        'input_ids': input_ids,
         | 
| 154 | 
            +
                        'cumulative_len': cumulative_len,
         | 
| 155 | 
            +
                        'position_ids': position_ids,
         | 
| 156 | 
            +
                        'labels': labels,
         | 
| 157 | 
            +
                        'max_seqlen': max_seqlen
         | 
| 158 | 
            +
                    }
         | 
| 159 | 
            +
                else:
         | 
| 160 | 
            +
                    data_dict = {
         | 
| 161 | 
            +
                        'input_ids': input_ids,
         | 
| 162 | 
            +
                        'attention_mask': attention_mask,
         | 
| 163 | 
            +
                        'position_ids': position_ids,
         | 
| 164 | 
            +
                        'labels': labels
         | 
| 165 | 
            +
                    }
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                if has_image:
         | 
| 168 | 
            +
                    if all(x.shape == pixel_values[0].shape for x in pixel_values):
         | 
| 169 | 
            +
                        pixel_values = torch.stack(pixel_values, dim=0)
         | 
| 170 | 
            +
                    data_dict['frames_per_batch'] = frames_per_batch
         | 
| 171 | 
            +
                    data_dict['pixel_values'] = pixel_values
         | 
| 172 | 
            +
                    if has_pe:
         | 
| 173 | 
            +
                        data_dict['image_grid_thw'] = image_grid_thw
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                if has_fast_image:
         | 
| 176 | 
            +
                    if all(x.shape == fast_pixel_values[0].shape for x in fast_pixel_values):
         | 
| 177 | 
            +
                        fast_pixel_values = torch.stack(fast_pixel_values, dim=0)
         | 
| 178 | 
            +
                    data_dict['fast_pixel_values'] = fast_pixel_values
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                if has_fast_exists:
         | 
| 181 | 
            +
                    data_dict['fast_exists'] = fast_exists
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                if has_vp:
         | 
| 184 | 
            +
                    data_dict['vp_overall_mask'] = torch.cat(vp_overall_mask, dim=0)
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                if has_prompt_mask:
         | 
| 187 | 
            +
                    data_dict['prompt_masks'] = prompt_masks
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                if has_grounding_image:
         | 
| 190 | 
            +
                    # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
         | 
| 191 | 
            +
                        # grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
         | 
| 192 | 
            +
                    data_dict['g_pixel_values'] = grounding_pixel_values
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                if has_mask:
         | 
| 195 | 
            +
                    data_dict['masks'] = object_masks
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                if has_bboxes:
         | 
| 198 | 
            +
                    data_dict['bboxes'] = object_bboxes
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                if has_points:
         | 
| 201 | 
            +
                    data_dict['points'] = prompt_points
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                if return_hf_format:
         | 
| 204 | 
            +
                    return data_dict
         | 
| 205 | 
            +
                else:
         | 
| 206 | 
            +
                    return {'data': data_dict, 'data_samples': None}
         | 
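Before moving on to the next file, a minimal standalone sketch of the padding logic above (toy token ids only; the real collate entry point and its flags live earlier in this file): variable-length input_ids become a padded batch whose attention mask is derived from the original sequence lengths rather than from the pad id.

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100  # same sentinel xtuner uses for masked label positions
pad_index = 0        # hypothetical pad token id

# Two toy sequences of different lengths, as collected per example above.
input_ids = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]
labels = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]

ori_length = [len(ids) for ids in input_ids]
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_index)
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

# Mask by original length rather than by pad id, since eos and pad can coincide.
attention_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
    attention_mask[i, :length] = True

print(input_ids)       # tensor([[5, 6, 7], [8, 9, 0]])
print(attention_mask)  # tensor([[ True,  True,  True], [ True,  True, False]])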
    	
projects/llava_sam2/datasets/encode_fn.py ADDED
@@ -0,0 +1,144 @@
import copy

from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX


def video_lisa_encode_fn(
        example,
        tokenizer,
        max_length,
        input_ids_with_output=True,
        **kwargs
):
    """We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
                {
                    'input': '',
                    'output': '### Human: Can you write xxx'
                }
            ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                }
            ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
                {
                    'input': 'Give three tips for staying healthy.',
                    'output': '1.Eat a balanced diet xxx'
                },
                {
                    'input': 'Please expand on the second point.',
                    'output': 'Here is an expanded explanation of the xxx'
                }
            ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input = single_turn_conversation['input']
        input_encode = tokenizer.encode(input, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output:
            # Add output
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            output_encode = tokenizer.encode(output, add_special_tokens=False)
            input_ids += output_encode
            if output_with_loss:
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)
            # Add EOS_TOKEN (with loss)
            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False
            # Add SEP (without loss)
            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)

    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {'input_ids': input_ids, 'labels': labels}
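As a quick illustration of the masking scheme implemented above, a small sketch with made-up token ids (no real tokenizer): BOS and prompt positions receive IGNORE_INDEX, so the loss is computed only on the answer and EOS.

IGNORE_INDEX = -100            # same sentinel as xtuner.utils.IGNORE_INDEX
bos, eos = [1], [2]            # hypothetical special-token ids
prompt = [5, 6, 7]             # tokenizer.encode(input, add_special_tokens=False)
answer = [8, 9]                # tokenizer.encode(output, add_special_tokens=False)

input_ids = bos + prompt + answer + eos
labels = [IGNORE_INDEX] * (len(bos) + len(prompt)) + answer + eos

assert input_ids == [1, 5, 6, 7, 8, 9, 2]
assert labels == [-100, -100, -100, -100, 8, 9, 2]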

def video_lisa_encode_multi_conv_fn(
        example,
        tokenizer,
        max_length,
        input_ids_with_output=True
):
    """Prompt-only variant of video_lisa_encode_fn: encodes each
    conversation's inputs without outputs and returns one id list per
    conversation. See video_lisa_encode_fn for the supported conversation
    formats.
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    assert not input_ids_with_output
    input_id_list = []
    for conv in example['conversation']:
        input_ids = []
        next_needs_bos_token = True
        for single_turn_conversation in conv:
            input = single_turn_conversation['input']
            input_encode = tokenizer.encode(input, add_special_tokens=False)
            if next_needs_bos_token:
                input_ids += bos_token_id
            input_ids += input_encode

        if len(input_ids) > max_length:
            input_ids = input_ids[:max_length]

        input_id_list.append(input_ids)
    return {'input_ids': input_id_list}
    	
projects/llava_sam2/datasets/gcg_process.py ADDED
@@ -0,0 +1,297 @@
import numpy as np
import random

from xtuner.utils import DEFAULT_IMAGE_TOKEN

GCG_QUESTIONS = [
    DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]


def refcocog_parse_annotations(example):
    # example: {'id': str, 'refs': [{'sentence', 'bbox', 'segmentation'}, ...],
    #           'img_file_name': str, 'caption': str}
    annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [],
                   'file_name': example['img_file_name'], 'image': example['img_file_name']}

    orig_caption = example['caption'].strip('"').strip()
    annotations['caption'] = orig_caption.lower()

    for detail in example['refs']:
        phrase = detail['sentence']
        if phrase.lower() in annotations['caption']:
            annotations['labels'].append(phrase)
            index = annotations['caption'].find(phrase)
            end_index = index + len(phrase) if index != -1 else -1
            annotations['tokens_positive'].append([index, end_index])
            # The mask is still stored as a polygon or RLE at this point.
            annotations['masks'].append(detail['segmentation'])

    # Sort tokens_positive and the corresponding lists
    tokens_positive = annotations['tokens_positive']
    sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0])
    annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices]
    annotations['masks'] = [annotations['masks'][i] for i in sorted_indices]
    annotations['labels'] = [annotations['labels'][i] for i in sorted_indices]

    # Trim overlapping intervals
    for i in range(len(tokens_positive)):
        for j in range(i + 1, len(tokens_positive)):
            # If there is overlap
            if tokens_positive[i][1] >= tokens_positive[j][0]:
                # Modify the end index of phrase i to be one less than the start index of phrase j
                tokens_positive[i][1] = tokens_positive[j][0] - 1
                # Modify the phrase to reflect the change in indices
                annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1]
                break  # Exit inner loop since i was modified

    return annotations


def refcocog_conversation(caption, tokens_positive):
    # Insert <p> </p> and [SEG] into the caption and select a question
    question = random.choice(GCG_QUESTIONS).strip()

    # Prepare caption with tags
    def tag_caption(caption, tokens):
        for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
            caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
        return caption

    detailed_answer = tag_caption(caption, tokens_positive)

    conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
    return conversations
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            def refcocog_preprocess(example):
         | 
| 68 | 
            +
                data_labels = example['labels']
         | 
| 69 | 
            +
                masks = example['masks']
         | 
| 70 | 
            +
                caption = example['caption']
         | 
| 71 | 
            +
                tokens_positive = example['tokens_positive']
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                # Function to sort elements based on the start index of each phrase
         | 
| 74 | 
            +
                def sort_by_start_index(items, order):
         | 
| 75 | 
            +
                    return [items[i] for i in order]
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                # Sort phrases based on their appearance in the sentence
         | 
| 78 | 
            +
                phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
         | 
| 79 | 
            +
                masks = sort_by_start_index(masks, phrase_order)
         | 
| 80 | 
            +
                data_labels = sort_by_start_index(data_labels, phrase_order)
         | 
| 81 | 
            +
                tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                conversations = refcocog_conversation(caption, tokens_positive)
         | 
| 84 | 
            +
                example['conversations'] = conversations
         | 
| 85 | 
            +
                example['labels'] = data_labels
         | 
| 86 | 
            +
                example['masks'] = masks
         | 
| 87 | 
            +
                example['tokens_positive'] = tokens_positive
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                return example
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            def glamm_refcocog_map_fn(example):
         | 
| 92 | 
            +
                # example {'id': str, 'refs': [{"setence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                example = refcocog_parse_annotations(example)
         | 
| 95 | 
            +
                # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                example = refcocog_preprocess(example)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # do llava preprocess
         | 
| 100 | 
            +
                messages = example['conversations']
         | 
| 101 | 
            +
                input = ''
         | 
| 102 | 
            +
                conversation = []
         | 
| 103 | 
            +
                while messages and messages[0]['from'] == 'gpt':
         | 
| 104 | 
            +
                    # Skip the first one if it is from gpt
         | 
| 105 | 
            +
                    messages = messages[1:]
         | 
| 106 | 
            +
                for msg in messages:
         | 
| 107 | 
            +
                    if msg['from'] == 'human':
         | 
| 108 | 
            +
                        if DEFAULT_IMAGE_TOKEN in msg['value']:
         | 
| 109 | 
            +
                            msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
         | 
| 110 | 
            +
                                                                '').strip()
         | 
| 111 | 
            +
                            msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
         | 
| 112 | 
            +
                            msg['value'] = msg['value'].strip()
         | 
| 113 | 
            +
                        input += msg['value']
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    elif msg['from'] == 'gpt':
         | 
| 116 | 
            +
                        conversation.append({'input': input, 'output': msg['value']})
         | 
| 117 | 
            +
                        input = ''
         | 
| 118 | 
            +
                    else:
         | 
| 119 | 
            +
                        raise NotImplementedError
         | 
| 120 | 
            +
                example.update({'conversation': conversation})
         | 
| 121 | 
            +
                return example
         | 
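A minimal sketch (toy messages, not a real dataset record) of the human/gpt-to-input/output conversion performed at the end of each map function above: the image token is moved to the front of the human turn, and each gpt turn closes one {'input', 'output'} pair.

DEFAULT_IMAGE_TOKEN = '<image>'  # matches xtuner.utils.DEFAULT_IMAGE_TOKEN

messages = [
    {'from': 'human', 'value': 'Describe the image. ' + DEFAULT_IMAGE_TOKEN},
    {'from': 'gpt', 'value': '<p> a dog </p> [SEG] beside <p> a red car </p> [SEG]'},
]

input, conversation = '', []
for msg in messages:
    if msg['from'] == 'human':
        if DEFAULT_IMAGE_TOKEN in msg['value']:
            msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
            msg['value'] = (DEFAULT_IMAGE_TOKEN + '\n' + msg['value']).strip()
        input += msg['value']
    elif msg['from'] == 'gpt':
        conversation.append({'input': input, 'output': msg['value']})
        input = ''

print(conversation)
# [{'input': '<image>\nDescribe the image.',
#   'output': '<p> a dog </p> [SEG] beside <p> a red car </p> [SEG]'}]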
| 122 | 
            +
             | 
| 123 | 
            +
            def grandf_parse_annotations(example):
         | 
| 124 | 
            +
                image_path = example['file_name']
         | 
| 125 | 
            +
                annotations = {
         | 
| 126 | 
            +
                    'labels': [], 'caption': [], 'masks': [],
         | 
| 127 | 
            +
                    'tokens_positive': [], 'file_name': image_path,
         | 
| 128 | 
            +
                    'image': image_path}
         | 
| 129 | 
            +
                annotations['caption'] = example['caption'].strip('"').strip()
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                for word, grounding in example["groundings"].items():
         | 
| 132 | 
            +
                    if grounding is None:
         | 
| 133 | 
            +
                        continue
         | 
| 134 | 
            +
                    annotations['labels'].append(word)
         | 
| 135 | 
            +
                    annotations['tokens_positive'].append(grounding["token_positives"])
         | 
| 136 | 
            +
                    annotations['masks'].append(grounding["rle_masks"])
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                return annotations
         | 
| 139 | 
            +
             | 
| 140 | 
            +
            def grandf_conversation(caption, tokens_positive):
         | 
| 141 | 
            +
                question = random.choice(GCG_QUESTIONS).strip()
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                # Prepare caption with tags
         | 
| 144 | 
            +
                def tag_caption(caption, tokens):
         | 
| 145 | 
            +
                    for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
         | 
| 146 | 
            +
                        caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
         | 
| 147 | 
            +
                    return caption
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                detailed_answer = tag_caption(caption, tokens_positive)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
         | 
| 152 | 
            +
                return conversations
         | 
| 153 | 
            +
            def grandf_preprocess(example):
         | 
| 154 | 
            +
                data_labels = example['labels']
         | 
| 155 | 
            +
                masks = example['masks']
         | 
| 156 | 
            +
                caption = example['caption']
         | 
| 157 | 
            +
                tokens_positive = example['tokens_positive']
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                # Function to sort elements based on the start index of each phrase
         | 
| 160 | 
            +
                def sort_by_start_index(items, order):
         | 
| 161 | 
            +
                    return [items[i] for i in order]
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                # Sort phrases based on their appearance in the sentence
         | 
| 164 | 
            +
                phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
         | 
| 165 | 
            +
                masks = sort_by_start_index(masks, phrase_order)
         | 
| 166 | 
            +
                data_labels = sort_by_start_index(data_labels, phrase_order)
         | 
| 167 | 
            +
                tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                conversations = grandf_conversation(caption, tokens_positive)
         | 
| 170 | 
            +
                example['conversations'] = conversations
         | 
| 171 | 
            +
                example['labels'] = data_labels
         | 
| 172 | 
            +
                example['masks'] = masks
         | 
| 173 | 
            +
                example['tokens_positive'] = tokens_positive
         | 
| 174 | 
            +
                return example
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            def glamm_granf_map_fn(example):
         | 
| 177 | 
            +
                # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
         | 
| 178 | 
            +
                # "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
         | 
| 179 | 
            +
                example = grandf_parse_annotations(example)
         | 
| 180 | 
            +
                # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                example = grandf_preprocess(example)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                # do llava preprocess
         | 
| 185 | 
            +
                messages = example['conversations']
         | 
| 186 | 
            +
                input = ''
         | 
| 187 | 
            +
                conversation = []
         | 
| 188 | 
            +
                while messages and messages[0]['from'] == 'gpt':
         | 
| 189 | 
            +
                    # Skip the first one if it is from gpt
         | 
| 190 | 
            +
                    messages = messages[1:]
         | 
| 191 | 
            +
                for msg in messages:
         | 
| 192 | 
            +
                    if msg['from'] == 'human':
         | 
| 193 | 
            +
                        if DEFAULT_IMAGE_TOKEN in msg['value']:
         | 
| 194 | 
            +
                            msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
         | 
| 195 | 
            +
                                                                '').strip()
         | 
| 196 | 
            +
                            msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
         | 
| 197 | 
            +
                            msg['value'] = msg['value'].strip()
         | 
| 198 | 
            +
                        input += msg['value']
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    elif msg['from'] == 'gpt':
         | 
| 201 | 
            +
                        conversation.append({'input': input, 'output': msg['value']})
         | 
| 202 | 
            +
                        input = ''
         | 
| 203 | 
            +
                    else:
         | 
| 204 | 
            +
                        raise NotImplementedError
         | 
| 205 | 
            +
                example.update({'conversation': conversation})
         | 
| 206 | 
            +
                return example
         | 
| 207 | 
            +
             | 
| 208 | 
            +
            glamm_openpsg_map_fn = glamm_granf_map_fn
         | 
| 209 | 
            +
             | 
| 210 | 
            +
            def flickr_parse_annotations(example):
         | 
| 211 | 
            +
                annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [],
         | 
| 212 | 
            +
                               'tokens_positive': [], 'image': example['file_name']}
         | 
| 213 | 
            +
                ann_info = example["ann_info"]
         | 
| 214 | 
            +
                for ann in ann_info:
         | 
| 215 | 
            +
                    if ann.get('ignore', False):
         | 
| 216 | 
            +
                        continue
         | 
| 217 | 
            +
                    x1, y1, w, h = ann['bbox']
         | 
| 218 | 
            +
                    inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0))
         | 
| 219 | 
            +
                    inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0))
         | 
| 220 | 
            +
                    if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
         | 
| 221 | 
            +
                        continue
         | 
| 222 | 
            +
                    bbox = [x1, y1, x1 + w, y1 + h]
         | 
| 223 | 
            +
                    annotations['bboxes'].append(bbox)
         | 
| 224 | 
            +
                    tokens_positive = ann['tokens_positive']
         | 
| 225 | 
            +
                    gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive]
         | 
| 226 | 
            +
                    annotations['labels'].append(gt_label[0])
         | 
| 227 | 
            +
                    annotations['tokens_positive'].append(tokens_positive[0])
         | 
| 228 | 
            +
             | 
| 229 | 
            +
                    rle = ann['sam_mask']
         | 
| 230 | 
            +
                    annotations['masks'].append(rle)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                # Convert bounding boxes to numpy arrays
         | 
| 233 | 
            +
                annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
         | 
| 234 | 
            +
                    'bboxes'] else np.zeros((0, 4), dtype=np.float32)
         | 
| 235 | 
            +
                annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[
         | 
| 236 | 
            +
                    'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32)
         | 
| 237 | 
            +
                return annotations
         | 
| 238 | 
            +
             | 
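A quick worked example (made-up numbers) of the validity check in flickr_parse_annotations: a COCO-style [x, y, w, h] box is clipped against the image, dropped when the clipped area is zero, and otherwise converted to [x1, y1, x2, y2].

width, height = 640, 480
x1, y1, w, h = 600, 100, 80, 50      # box sticks out past the right edge
inter_w = max(0, min(x1 + w, width) - max(x1, 0))    # min(680, 640) - 600 = 40
inter_h = max(0, min(y1 + h, height) - max(y1, 0))   # min(150, 480) - 100 = 50
assert inter_w * inter_h > 0         # 40 * 50 = 2000, so the box is kept
bbox = [x1, y1, x1 + w, y1 + h]      # [600, 100, 680, 150] in xyxy format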
def flickr_preprocess(example):
    data_labels = example['labels']
    masks = example['masks']
    caption = example['caption']
    tokens_positive = example['tokens_positive']

    # Sort elements based on the start index of each phrase
    def sort_by_start_index(items, order):
        return [items[i] for i in order]

    # Sort phrases based on their appearance in the sentence
    phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
    masks = sort_by_start_index(masks, phrase_order)
    data_labels = sort_by_start_index(data_labels, phrase_order)
    tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

    conversations = grandf_conversation(caption, tokens_positive)
    example['conversations'] = conversations
    example['labels'] = data_labels
    example['masks'] = masks
    example['tokens_positive'] = tokens_positive
    return example


def glamm_flickr_map_fn(example):
    # example: {'file_name': str, 'height': int, 'width': int, 'image_id': str,
    #           'caption': str,
    #           'ann_info': [{'bbox', 'area', 'tokens_positive', 'sam_mask', ...}, ...]}
    example = flickr_parse_annotations(example)

    example = flickr_preprocess(example)

    # Do the LLaVA-style preprocess
    messages = example['conversations']
    input = ''
    conversation = []
    while messages and messages[0]['from'] == 'gpt':
        # Skip the first message if it is from gpt
        messages = messages[1:]
    for msg in messages:
        if msg['from'] == 'human':
            if DEFAULT_IMAGE_TOKEN in msg['value']:
                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
                                                    '').strip()
                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                msg['value'] = msg['value'].strip()
            input += msg['value']

        elif msg['from'] == 'gpt':
            conversation.append({'input': input, 'output': msg['value']})
            input = ''
        else:
            raise NotImplementedError
    example.update({'conversation': conversation})
    return example
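To tie the pieces of this file together, a toy end-to-end call of glamm_granf_map_fn (made-up record; real GranD-f samples carry RLE masks where the placeholder strings appear):

example = {
    'file_name': 'images/0001.jpg',
    'height': 480, 'width': 640,
    'image_id': '0001',
    'caption': 'a dog beside a red car',
    'groundings': {
        'a dog': {'token_positives': [0, 5], 'rle_masks': ['rle']},
        'a red car': {'token_positives': [13, 22], 'rle_masks': ['rle']},
    },
}

out = glamm_granf_map_fn(example)
print(out['conversation'][0]['output'])
# <p> a dog </p> [SEG] beside <p> a red car </p> [SEG]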
    	
projects/llava_sam2/datasets/grand_process.py ADDED
@@ -0,0 +1,110 @@
import numpy as np
import random

from xtuner.utils import DEFAULT_IMAGE_TOKEN

GCG_QUESTIONS = [
    DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]


def grand_parse_annotations(example):
    annotations = {
        'caption': [], 'masks': [],
        'tokens_positive': [], 'labels': []}
    annotations['caption'] = example['dense_caption']['caption'].strip('"').strip()
    object_infos = example['dense_caption']['details']

    # Index every segmented object (regular and floating) by its id.
    all_seg_objects_dict = {}
    for seg_object_dict in example["objects"]:
        all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
    for seg_object_dict in example["floating_objects"]:
        all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict

    for object_info in object_infos:
        ids = object_info["ids"]
        if object_info["tokens_positive"] is None:
            continue
        annotations['labels'].append(object_info["phrase"])
        annotations['tokens_positive'].append(object_info["tokens_positive"])
        _masks = []
        for _id in ids:
            _masks.append(all_seg_objects_dict[_id]['segmentation'])
        annotations['masks'].append(_masks)
    return annotations

def grand_conversation(caption, tokens_positive):
    question = random.choice(GCG_QUESTIONS).strip()

    # Prepare the caption with grounding tags. Spans are spliced in from the
    # rightmost one first so earlier character offsets stay valid.
    def tag_caption(caption, tokens):
        for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
            caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
        return caption

    detailed_answer = tag_caption(caption, tokens_positive)

    conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
    return conversations

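To make the tagging step concrete, here is a minimal, self-contained sketch of what `tag_caption` produces; the caption and character spans below are hypothetical, not taken from the dataset:

# Hypothetical caption with two grounded phrases, "dog" and "ball".
caption = "A dog chases a ball"
tokens_positive = [(2, 5), (15, 19)]  # character spans of the two phrases

# Same right-to-left splice as tag_caption above.
for start, end in sorted(tokens_positive, key=lambda x: x[0], reverse=True):
    caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"

assert caption == "A <p> dog </p> [SEG] chases a <p> ball </p> [SEG]"
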
def grand_preprocess(example):
    data_labels = example['labels']
    masks = example['masks']
    caption = example['caption']
    tokens_positive = example['tokens_positive']

    # Function to sort elements based on the start index of each phrase
    def sort_by_start_index(items, order):
        return [items[i] for i in order]

    # Sort phrases based on their appearance in the sentence
    phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
    masks = sort_by_start_index(masks, phrase_order)
    data_labels = sort_by_start_index(data_labels, phrase_order)
    tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

    conversations = grand_conversation(caption, tokens_positive)
    example['conversations'] = conversations
    example['labels'] = data_labels
    example['masks'] = masks
    example['tokens_positive'] = tokens_positive
    return example

def glamm_grand_map_fn(example):
    # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
    # "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
    example = grand_parse_annotations(example)
    # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file

    example = grand_preprocess(example)

    # do llava preprocess
    messages = example['conversations']
    input = ''
    conversation = []
    while messages and messages[0]['from'] == 'gpt':
        # Skip the first one if it is from gpt
        messages = messages[1:]
    for msg in messages:
        if msg['from'] == 'human':
            if DEFAULT_IMAGE_TOKEN in msg['value']:
                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                msg['value'] = msg['value'].strip()
            input += msg['value']
        elif msg['from'] == 'gpt':
            conversation.append({'input': input, 'output': msg['value']})
            input = ''
        else:
            raise NotImplementedError
    example.update({'conversation': conversation})
    return example
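For orientation, a hedged sketch of the shape a sample takes after `glamm_grand_map_fn`; all values below are made up for illustration, not real GranD data:

# Illustrative shape of the mapped sample (hypothetical values):
mapped = {
    'conversation': [{
        'input': '<image>\nPlease briefly describe the contents of the image. '
                 'Please respond with interleaved segmentation masks for the '
                 'corresponding parts of the answer.',
        'output': 'A <p> dog </p> [SEG] chases a <p> ball </p> [SEG]',
    }],
    # 'labels', 'masks' and 'tokens_positive' are carried through as well,
    # reordered by where each phrase appears in the caption.
}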
    	
        projects/llava_sam2/datasets/utils.py
    ADDED
    
    | @@ -0,0 +1,58 @@ | |
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
                              image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: prefer the grid with more tiles once the source image
            # is large enough to fill at least half of that grid's pixel budget.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

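A quick hedged sanity check of the selection logic (the grid enumeration is copied from `dynamic_preprocess` below): a 1600x900 image, aspect ratio of roughly 1.78, should snap to the 2x1 grid among the candidates:

ratios = sorted({(i, j) for n in range(1, 7)
                 for i in range(1, n + 1) for j in range(1, n + 1)
                 if 1 <= i * j <= 6}, key=lambda x: x[0] * x[1])
assert find_closest_aspect_ratio(1600 / 900, ratios, 1600, 900, 448) == (2, 1)
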
def dynamic_preprocess(image,
                       min_num=1,
                       max_num=6,
                       image_size=448,
                       use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # enumerate the candidate tile grids (i columns x j rows)
    target_ratios = {(i, j)
                     for n in range(min_num, max_num + 1)
                     for i in range(1, n + 1) for j in range(1, n + 1)
                     if i * j <= max_num and i * j >= min_num}
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
                                                    target_ratios, orig_width,
                                                    orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image, then crop it into image_size x image_size tiles
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images
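And a hedged usage sketch of the whole tiling pipeline; the input image here is synthetic:

from PIL import Image

# Synthetic 1600x900 image: expect a 2x1 grid of 448x448 tiles plus a thumbnail.
img = Image.new('RGB', (1600, 900))
tiles = dynamic_preprocess(img, min_num=1, max_num=6,
                           image_size=448, use_thumbnail=True)
assert len(tiles) == 3                          # 2 grid tiles + 1 thumbnail
assert all(t.size == (448, 448) for t in tiles)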
    	
        projects/llava_sam2/datasets/vqa_dataset.py
    ADDED
    
    | @@ -0,0 +1,509 @@ | |
import copy
import random
import glob
import json
import logging
import os
from typing import Literal

import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER
from xtuner.utils import IGNORE_INDEX
from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square, SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from .utils import dynamic_preprocess


class InfinityMMDataset(Dataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 tokenizer,
                 data_path,
                 prompt_template,
                 special_tokens=None,
                 max_length=8192,
                 offline_save_path='./work_dirs/infinityMM.json',
                 ):
        self.offline_save_path = offline_save_path
        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)
        self._system = ''

        self.template = prompt_template
        self.max_length = max_length

        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        self.image_size = 448
        self.use_thumbnail = True
        patch_size = 14
        # LLM tokens contributed by one 448x448 tile:
        # (448 // 14)^2 * 0.5^2 = 256
        self.patch_token = int(
            (self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))

        self.transformer = T.Compose([
            T.Lambda(lambda img: img.convert('RGB')
                     if img.mode != 'RGB' else img),
            T.Resize((self.image_size, self.image_size),
                     interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
        ])

        self.data = self._load_annotations(data_path)
        self._max_refetch = 1000

    def _load_annotations(self, data_path):
        if os.path.exists(self.offline_save_path):
            with open(self.offline_save_path, 'r') as f:
                ret = json.load(f)
            print(f"Load InfinityMM file list from {self.offline_save_path}, {len(ret)} items !!!")
            return ret
        sub_folders = []
        for sub_folder in os.listdir(data_path):
            if '.' not in sub_folder:
                # a folder
                if "LVIS_111k" in sub_folder:
                    # special case: has an extra level of sub-folders
                    subsub_folders = os.listdir(os.path.join(data_path, sub_folder))
                    for subsub_folder in subsub_folders:
                        sub_folders.append(os.path.join(data_path, sub_folder, subsub_folder))
                else:
                    sub_folders.append(os.path.join(data_path, sub_folder))

        all_jsons = []
        for sub_folder in sub_folders:
            print(f"Processing {sub_folder} !!!")
            _files = os.listdir(sub_folder)
            _num = 0
            for _file in _files:
                if '.json' in _file:
                    _num += 1
                    all_jsons.append(os.path.join(sub_folder, _file))
            print(f"Finished {sub_folder}, which has {_num} items.")

        with open(self.offline_save_path, 'w') as f:
            json.dump(all_jsons, f)

        return all_jsons

    def __getitem__(self, index):
        for _ in range(self._max_refetch + 1):
            data = self.prepare_data(index)
            # Broken images may cause the returned data to be None
            if data is None:
                index = self._rand_another()
                continue
            return data

    def __len__(self):
        return len(self.data)

    @property
    def modality_length(self):
        # Every sample is assigned the same nominal length; the value is only
        # a placeholder for length-grouped sampling.
        self.group_length = [100] * len(self.data)
        return self.group_length

    @property
    def length(self):
        group_length = np.array(self.group_length)
        group_length = np.abs(group_length).tolist()
        return group_length

    def prepare_data(self, index):
        data_path = self.data[index]

        with open(data_path, 'r') as f:
            data_dict = json.load(f)
        if data_dict is None:
            return None
        if 'image' in data_dict.keys():
            # The image sits next to its annotation file, e.g. 0001.json -> 0001.jpg
            data_dict['image'] = data_path.replace('.json', '.jpg')

        out_data_dict = {}

        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                print(f'Error: {e}', flush=True)
                print_log(f'Error: {e}', logger='current')
                return None

            images = dynamic_preprocess(image, self.min_dynamic_patch,
                                        self.max_dynamic_patch,
                                        self.image_size, self.use_thumbnail)
            pixel_values = [self.transformer(image) for image in images]
            pixel_values = torch.stack(pixel_values)
            out_data_dict['pixel_values'] = pixel_values

            num_image_tokens = pixel_values.shape[0] * self.patch_token
            image_token_str = f'{self.IMG_START_TOKEN}' \
                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                              f'{self.IMG_END_TOKEN}'
            token_dict = self.get_inputid_labels(
                data_dict['conversations'], image_token_str)
            out_data_dict.update(token_dict)
        else:
            token_dict = self.get_inputid_labels(
                data_dict['conversations'], None)
            out_data_dict.update(token_dict)
            out_data_dict['pixel_values'] = torch.zeros(
                1, 3, self.image_size, self.image_size)
        return out_data_dict

    def _rand_another(self) -> int:
        return np.random.randint(0, len(self.data))

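A quick hedged check of the token arithmetic above: with 448-pixel tiles, 14-pixel patches, and a 0.5 downsample ratio, each tile contributes 256 image-context tokens, so a two-tile-plus-thumbnail image yields 768:

image_size, patch_size, downsample_ratio = 448, 14, 0.5
patch_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
assert patch_token == 256
assert 3 * patch_token == 768   # e.g. 2 grid tiles + 1 thumbnail
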
    def get_inputid_labels(self, conversations, image_token_str) -> dict:
        input = ''
        out_conversation = []
        while conversations and conversations[0]['from'] == 'gpt':
            # Skip the first one if it is from gpt
            conversations = conversations[1:]
        for i, msg in enumerate(conversations):
            if msg['from'] == 'human':
                # Keep at most one <image> placeholder, anchored to the first
                # human turn.
                if '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>\n', '').replace('<image>', '')
                    if i == 0:
                        msg['value'] = "<image>\n" + msg['value']

                if image_token_str is None and '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', '')
                if '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
                input += msg['value'].strip()
            elif msg['from'] == 'gpt':
                out_conversation.append({
                    'input': input,
                    'output': msg['value'].strip()
                })
                input = ''
            else:
                raise NotImplementedError

        input_ids, labels = [], []
        for i, single_turn_conversation in enumerate(out_conversation):
            input = single_turn_conversation.get('input', '')
            if input is None:
                input = ''
            input_text = self.template.INSTRUCTION.format(
                input=input, round=i + 1)

            if i == 0:
                if self._system != '' and self._system is not None:
                    system = self.template.SYSTEM.format(system=self._system)
                    input_text = system + input_text
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=True)
            else:
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=False)
            input_ids += input_encode
            labels += [IGNORE_INDEX] * len(input_encode)

            output_text = single_turn_conversation.get('output', '')
            if self.template.get('SUFFIX', None):
                output_text += self.template.SUFFIX
            output_encode = self.tokenizer.encode(
                output_text, add_special_tokens=False)
            input_ids += output_encode
            labels += copy.deepcopy(output_encode)

        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
            print_log(
                f'Warning: input_ids length({len(input_ids)}) '
                f'is longer than max_length, cut to {self.max_length}',
                logger='current')
        return {'input_ids': input_ids, 'labels': labels}

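The prompt/answer masking above follows the usual supervised fine-tuning scheme: prompt tokens receive `IGNORE_INDEX` labels so only answer tokens contribute to the loss. A hedged, tokenizer-free sketch with hypothetical token ids:

IGNORE_INDEX = -100  # the sentinel xtuner uses

# Hypothetical ids for one turn: 4 prompt tokens, 3 answer tokens.
prompt_ids, answer_ids = [101, 7, 8, 9], [20, 21, 102]
input_ids = prompt_ids + answer_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
assert len(input_ids) == len(labels)
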
class LLaVADataset(Dataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 tokenizer,
                 data_path,
                 prompt_template,
                 special_tokens=None,
                 image_folder=None,
                 max_length=8192,
                 arch_type: Literal['intern_vl', 'qwen', 'llava'] = 'intern_vl',
                 preprocessor=None,
                 skip_pure_text=False,
                 ):

        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.image_folder = image_folder
        self.template = prompt_template
        self.max_length = max_length

        self._system = ''

        self.arch_type = arch_type
        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        if self.arch_type == 'llava':
            self.downsample_ratio = 1
        self.image_size = 448
        if self.arch_type == 'llava':
            self.image_size = 336
        self.use_thumbnail = True
        patch_size = 14
        self.patch_token = int(
            (self.image_size // patch_size)**2 * (self.downsample_ratio**2))

        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)

        self.data = self._load_annotations(data_path, image_folder)
        self._max_refetch = 1000

        self.skip_pure_text = skip_pure_text

    def _load_annotations(self, data_path, image_folder=None):
        with open(data_path) as f:
            data = json.load(f)
        return data

    def __getitem__(self, index):
        for _ in range(self._max_refetch + 1):
            data = self.prepare_data(index)
            # Broken images may cause the returned data to be None
            if data is None:
                index = self._rand_another()
                continue
            return data

    def __len__(self):
        return len(self.data)

    @property
    def modality_length(self):
        self.group_length = [100] * len(self.data)
        return self.group_length

    @property
    def length(self):
        group_length = np.array(self.group_length)
        group_length = np.abs(group_length).tolist()
        return group_length

    def prepare_data(self, index):
        data_dict: dict = self.data[index]

        if data_dict is None:
            return None

        out_data_dict = {}

        if self.skip_pure_text and data_dict.get('image', None) is None:
            return None

        if data_dict.get('image', None) is not None:
            image_file = os.path.join(self.image_folder, data_dict['image'])
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                print(f'Error: {e}', flush=True)
                print_log(f'Error: {e}', logger='current')
                return None
            if self.preprocessor is not None:
                # HF preprocessor path: a single image, no dynamic tiling.
                images = [image]
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(images, do_resize=True)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
                    num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
                else:
                    raise NotImplementedError
                out_data_dict.update(_data_dict)
            else:
                images = dynamic_preprocess(image, self.min_dynamic_patch,
                                            self.max_dynamic_patch,
                                            self.image_size, self.use_thumbnail)
                pixel_values = [self.transformer(image) for image in images]
                pixel_values = torch.stack(pixel_values)
                out_data_dict['pixel_values'] = pixel_values

                num_image_tokens = pixel_values.shape[0] * self.patch_token
            image_token_str = f'{self.IMG_START_TOKEN}' \
                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                              f'{self.IMG_END_TOKEN}'
            token_dict = self.get_inputid_labels(
                data_dict['conversations'], image_token_str)
            out_data_dict.update(token_dict)
        else:
            token_dict = self.get_inputid_labels(
                data_dict['conversations'], None)
            out_data_dict.update(token_dict)
            out_data_dict['pixel_values'] = torch.zeros(
                1, 3, self.image_size, self.image_size)
        return out_data_dict

    def _rand_another(self) -> int:
        return np.random.randint(0, len(self.data))

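For the qwen branch above, the token count is derived from the processor's vision grid rather than from tile counts. A hedged sketch of that arithmetic, with a hypothetical grid (1 temporal step, 32x32 spatial patches, merged 2x2 by the 0.5 downsample ratio):

import torch

image_grid_thw = torch.tensor([[1, 32, 32]])  # hypothetical grid
downsample_ratio = 0.5
num_image_tokens = int(image_grid_thw[0].prod() * downsample_ratio ** 2)
assert num_image_tokens == 256
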
    def get_inputid_labels(self, conversations, image_token_str) -> dict:
        input = ''
        out_conversation = []
        while conversations and conversations[0]['from'] == 'gpt':
            # Skip the first one if it is from gpt
            conversations = conversations[1:]
        for msg in conversations:
            if msg['from'] == 'human':
                if image_token_str is None and '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', '')
                if '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
                input += msg['value'].strip()
            elif msg['from'] == 'gpt':
                out_conversation.append({
                    'input': input,
                    'output': msg['value'].strip()
                })
                input = ''
            else:
                raise NotImplementedError

        input_ids, labels = [], []
        for i, single_turn_conversation in enumerate(out_conversation):
            input = single_turn_conversation.get('input', '')
            if input is None:
                input = ''
            input_text = self.template.INSTRUCTION.format(
                input=input, round=i + 1)

            if i == 0:
                if self._system != '' and self._system is not None:
                    system = self.template.SYSTEM.format(system=self._system)
                    input_text = system + input_text
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=True)
            else:
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=False)
            input_ids += input_encode
            labels += [IGNORE_INDEX] * len(input_encode)

            output_text = single_turn_conversation.get('output', '')
            if self.template.get('SUFFIX', None):
                output_text += self.template.SUFFIX
            output_encode = self.tokenizer.encode(
                output_text, add_special_tokens=False)
            input_ids += output_encode
            labels += copy.deepcopy(output_encode)

        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
            print_log(
                f'Warning: input_ids length({len(input_ids)}) '
                f'is longer than max_length, cut to {self.max_length}',
                logger='current')
        return {'input_ids': input_ids, 'labels': labels}


| 480 | 
            +
            if __name__ == '__main__':
         | 
| 481 | 
            +
                from transformers import CLIPImageProcessor, AutoTokenizer
         | 
| 482 | 
            +
                from third_parts.segment_anything.utils.transforms import ResizeLongestSide
         | 
| 483 | 
            +
                pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
         | 
| 484 | 
            +
                llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                tokenizer = dict(
         | 
| 487 | 
            +
                    type=AutoTokenizer.from_pretrained,
         | 
| 488 | 
            +
                    pretrained_model_name_or_path=llm_name_or_path)
         | 
| 489 | 
            +
                image_processor = dict(
         | 
| 490 | 
            +
                    type=CLIPImageProcessor.from_pretrained,
         | 
| 491 | 
            +
                    pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
         | 
| 492 | 
            +
                extra_image_processor = dict(
         | 
| 493 | 
            +
                    type=ResizeLongestSide,
         | 
| 494 | 
            +
                    target_length=1024,
         | 
| 495 | 
            +
                )
         | 
| 496 | 
            +
                from xtuner.utils.templates import PROMPT_TEMPLATE
         | 
| 497 | 
            +
                prompt_template = PROMPT_TEMPLATE.vicuna
         | 
| 498 | 
            +
                from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
         | 
| 499 | 
            +
                from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                dataset = LLaVADataset(
         | 
| 502 | 
            +
                    tokenizer=tokenizer,
         | 
| 503 | 
            +
                    data_path='data/llava_data/LLaVA-Instruct-150K/llava_instruct_150k.json',
         | 
| 504 | 
            +
                    prompt_template=prompt_template,
         | 
| 505 | 
            +
                    special_tokens=['[SEG]'],
         | 
| 506 | 
            +
                    image_folder='data/coco/train2017/',
         | 
| 507 | 
            +
                )
         | 
| 508 | 
            +
                for i in range(1000):
         | 
| 509 | 
            +
                    dataset[i]
         | 
    	
projects/llava_sam2/deepspeed_zero2_sam2.json
ADDED
@@ -0,0 +1,24 @@
{
  "gradient_accumulation_steps": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_clipping": "auto",
  "zero_allow_untested_optimizer": true,
  "zero_force_ds_cpu_optimizer": false,
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "allgather_bucket_size": 5368709120,
    "reduce_bucket_size": 5368709120,
    "reduce_scatter": true,
    "sub_group_size": 1e9,
    "contiguous_gradients": true,
    "allgather_partitions": true
  },
  "fp16": {
    "enabled": false,
    "initial_scale_power": 16
  },
  "bf16": {
    "enabled": true
  }
}
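
A rough sketch (an assumption, not repo code) of handing this ZeRO-2 config to the DeepSpeed engine directly. xtuner normally resolves the "auto" fields, so a standalone run must fill them in, and the script is expected to run under the deepspeed launcher; the model below is a stand-in.

import json
import deepspeed
import torch
import torch.nn as nn

with open('projects/llava_sam2/deepspeed_zero2_sam2.json') as f:
    ds_config = json.load(f)

# "auto" placeholders are normally resolved by the training framework;
# substitute concrete values for a standalone run.
ds_config['train_micro_batch_size_per_gpu'] = 1
ds_config['gradient_accumulation_steps'] = 4
ds_config['gradient_clipping'] = 1.0

model = nn.Linear(8, 8)  # stand-in for the actual Sa2VA model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config=ds_config)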
    	
projects/llava_sam2/gradio/app.py
ADDED
@@ -0,0 +1,151 @@
import gradio as gr
import sys
from projects.llava_sam2.gradio.app_utils import\
    process_markdown, show_mask_pred, description, preprocess_video,\
    show_mask_pred_video, image2video_and_save

import torch
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel, GenerationConfig)
import argparse
import os

TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')

def parse_args(args):
    parser = argparse.ArgumentParser(description="Sa2VA Demo")
    parser.add_argument('hf_path', help='Sa2VA hf path.')
    return parser.parse_args(args)

def inference(image, video, follow_up, input_str):
    input_image = image
    if image is not None and (video is not None and os.path.exists(video)):
        return image, video, "Error: Please input only an image or a video, not both!"
    if image is None and (video is None or not os.path.exists(video)) and not follow_up:
        return image, video, "Error: Please input an image or a video!"

    if not follow_up:
        # reset the conversation state
        print('Log: History responses have been removed!')
        global_infos.n_turn = 0
        global_infos.inputs = ''
        text = input_str

        image = input_image
        global_infos.image_for_show = image
        global_infos.image = image
        global_infos.video = video

        if image is not None:
            global_infos.input_type = "image"
        else:
            global_infos.input_type = "video"

    else:
        text = input_str
        image = global_infos.image
        video = global_infos.video

    input_type = global_infos.input_type
    if input_type == "video":
        video = preprocess_video(video, global_infos.inputs + input_str)

    past_text = global_infos.inputs

    if past_text == "" and "<image>" not in text:
        text = "<image>" + text
    if input_type == "image":
        input_dict = {
            'image': image,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }
    else:
        input_dict = {
            'video': video,
            'text': text,
            'past_text': past_text,
            'mask_prompts': None,
            'tokenizer': tokenizer,
        }

    return_dict = sa2va_model.predict_forward(**input_dict)
    global_infos.inputs = return_dict["past_text"]
    print(return_dict['past_text'])
    if return_dict.get('prediction_masks'):  # non-empty list of predicted masks
        if input_type == "image":
            image_mask_show, selected_colors = show_mask_pred(
                global_infos.image_for_show, return_dict['prediction_masks'])
            video_mask_show = global_infos.video
        else:
            image_mask_show = None
            video_mask_show, selected_colors = show_mask_pred_video(
                video, return_dict['prediction_masks'])
            video_mask_show = image2video_and_save(video_mask_show, save_path="./ret_video.mp4")
    else:
        image_mask_show = global_infos.image_for_show
        video_mask_show = global_infos.video
        selected_colors = []

    predict = return_dict['prediction'].strip()
    global_infos.n_turn += 1

    predict = process_markdown(predict, selected_colors)
    return image_mask_show, video_mask_show, predict

def init_models(args):
    model_path = args.hf_path
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=True,
        trust_remote_code=True,
    ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    return model, tokenizer

class global_infos:
    inputs = ''
    n_turn = 0
    image_width = 0
    image_height = 0

    image_for_show = None
    image = None
    video = None

    input_type = "image"  # "image" or "video"

if __name__ == "__main__":
    # parse args and set up the model
    args = parse_args(sys.argv[1:])

    sa2va_model, tokenizer = init_models(args)

    demo = gr.Interface(
        inference,
        inputs=[
            gr.Image(type="pil", label="Upload Image", height=360),
            gr.Video(sources=["upload", "webcam"], label="Upload mp4 video", height=360),
            gr.Checkbox(label="Follow up Question"),
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        ],
        outputs=[
            gr.Image(type="pil", label="Output Image"),
            gr.Video(label="Output Video", show_download_button=True, format='mp4'),
            gr.Markdown()],
        theme=gr.themes.Soft(), allow_flagging="auto", description=description,
        title='Sa2VA'
    )

    demo.queue()
    demo.launch(share=True)
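
For reference, a minimal sketch of driving the same predict_forward interface without Gradio. The checkpoint name below is an assumption; the input and return keys mirror the input_dict and return_dict handling in app.py above.

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

path = 'ByteDance/Sa2VA-4B'  # assumed checkpoint; substitute your own Sa2VA hf path
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
    use_flash_attn=True, trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

image = Image.open('example.jpg').convert('RGB')
result = model.predict_forward(
    image=image,
    text='<image>Can you please segment the person in the given image?',
    past_text='',
    mask_prompts=None,
    tokenizer=tokenizer,
)
print(result['prediction'])                 # text answer, may contain [SEG]
masks = result.get('prediction_masks', [])  # one mask array per [SEG], as consumed by show_mask_pred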
    	
projects/llava_sam2/gradio/app_utils.py
ADDED
@@ -0,0 +1,293 @@
import numpy as np
from PIL import Image
import cv2

markdown_default = """
<link href="https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap" rel="stylesheet">
<style>
        .highlighted-text {
            font-family: 'Montserrat', sans-serif;
            font-weight: 600;
            font-size: 14px;
            color: rgb(255, 255, 239);
            background-color: rgb(225, 231, 254);
            border-radius: 7px;
            padding: 5px 7px;
            display: inline-block;
        }
        .regular-text {
            font-family: 'Montserrat', sans-serif;
            font-weight: 400;
            font-size: 14px;
        }
        .highlighted-response {
            font-family: 'Montserrat', sans-serif;
            font-weight: 600;
            font-size: 14px;
            border-radius: 6px;
            padding: 3px 4px;
            display: inline-block;
        }
</style>
<span class="highlighted-text" style='color:rgb(107, 100, 239)'>Sa2VA</span>
"""

description = """
**Usage** : <br>
 (1) For **Grounded Caption Generation / Interleaved Segmentation**, use a prompt like: *"Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer."* <br>
 (2) For **Segmentation Output**, use a prompt like: *"Can you please segment xxx in the given image?"* <br>
 (3) For **Image Captioning / VQA**, use a prompt like: *"Could you please give me a detailed description of the image?"* <br>
 (4) For **Image Conversation**, input an arbitrary text instruction. <br>
"""

ONE_THIRD = 1.0/3.0
ONE_SIXTH = 1.0/6.0
TWO_THIRD = 2.0/3.0

def desaturate(rgb, factor=0.65):
    """
    Mute an RGB color by replacing its lightness with `factor` in HLS space.

    :param rgb: A tuple of (r, g, b) where each value is in [0, 255].
    :param factor: The lightness to assign; lower values give darker,
                   more muted colors.
    :return: A tuple of adjusted (r, g, b) values in [0, 255].
    """
    r, g, b = [x / 255.0 for x in rgb]
    h, l, s = rgb_to_hls(r, g, b)
    l = factor  # note: the lightness channel is set directly, not scaled
    new_r, new_g, new_b = hls_to_rgb(h, l, s)
    return (int(new_r * 255), int(new_g * 255), int(new_b * 255))

def rgb_to_hls(r, g, b):
    maxc = max(r, g, b)
    minc = min(r, g, b)
    sumc = (maxc+minc)
    rangec = (maxc-minc)
    l = sumc/2.0
    if minc == maxc:
        return 0.0, l, 0.0
    if l <= 0.5:
        s = rangec / sumc
    else:
        s = rangec / (2.0-sumc)
    rc = (maxc-r) / rangec
    gc = (maxc-g) / rangec
    bc = (maxc-b) / rangec
    if r == maxc:
        h = bc-gc
    elif g == maxc:
        h = 2.0+rc-bc
    else:
        h = 4.0+gc-rc
    h = (h/6.0) % 1.0
    return h, l, s

def hls_to_rgb(h, l, s):
    if s == 0.0:
        return l, l, l
    if l <= 0.5:
        m2 = l * (1.0+s)
    else:
        m2 = l+s-(l*s)
    m1 = 2.0*l - m2
    return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD))

def _v(m1, m2, hue):
    hue = hue % 1.0
    if hue < ONE_SIXTH:
        return m1 + (m2-m1)*hue*6.0
    if hue < 0.5:
        return m2
    if hue < TWO_THIRD:
        return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0
    return m1

def process_markdown(output_str, colors):
    output_str = output_str.replace("\n", "").replace("  ", " ").replace("<s>", "")\
        .replace("<|im_end|>", '').replace("<|end|>", "")
    output_str = output_str.split("ASSISTANT: ")[-1]

    markdown_out = output_str
    markdown_out = markdown_out.replace(
        "<p>", "<span class='highlighted-response' style='background-color:rgb[COLOR]'>"
    )
    markdown_out = markdown_out.replace("</p>", "</span>")

    # each highlighted span gets the (desaturated) color of its predicted mask
    for color in colors:
        markdown_out = markdown_out.replace("[COLOR]", str(desaturate(tuple(color))), 1)

    markdown_out = f"""
    {markdown_out}
    """
    markdown_out = markdown_default + "<p><span class='regular-text'>" + markdown_out
    return markdown_out

def show_mask_pred(image, masks):
    masks = [mask[:1] for mask in masks]
    masks = np.concatenate(masks, axis=0)  # (n, h, w)

    selected_colors = []

    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
              (255, 255, 0), (255, 0, 255), (0, 255, 255),
              (128, 128, 255), [255, 192, 203],  # Pink
              [165, 42, 42],    # Brown
              [255, 165, 0],    # Orange
              [128, 0, 128],    # Purple
              [0, 0, 128],      # Navy
              [128, 0, 0],      # Maroon
              [128, 128, 0],    # Olive
              [70, 130, 180],   # Steel Blue
              [173, 216, 230],  # Light Blue
              [255, 192, 0],    # Gold
              [255, 165, 165],  # Light Salmon
              [255, 20, 147],   # Deep Pink
              ]

    _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8)

    for i, mask in enumerate(masks):
        color = colors[i % len(colors)]
        selected_colors.append(color)
        _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
        _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
        _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]

    image = np.array(image)
    image = image * 0.5 + _mask_image * 0.5
    image = image.astype(np.uint8)
    return image, selected_colors

def show_mask_pred_video(video, masks):
    ret_video = []
    selected_colors = []
    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
              (255, 255, 0), (255, 0, 255), (0, 255, 255),
              (128, 128, 255), [255, 192, 203],  # Pink
              [165, 42, 42],    # Brown
              [255, 165, 0],    # Orange
              [128, 0, 128],    # Purple
              [0, 0, 128],      # Navy
              [128, 0, 0],      # Maroon
              [128, 128, 0],    # Olive
              [70, 130, 180],   # Steel Blue
              [173, 216, 230],  # Light Blue
              [255, 192, 0],    # Gold
              [255, 165, 165],  # Light Salmon
              [255, 20, 147],   # Deep Pink
              ]
    for i_frame in range(len(video)):
        frame_masks = [mask[i_frame:i_frame+1] for mask in masks]
        frame_masks = np.concatenate(frame_masks, axis=0)
        _mask_image = np.zeros((frame_masks.shape[1], frame_masks.shape[2], 3), dtype=np.uint8)

        for i, mask in enumerate(frame_masks):
            if i_frame == 0:
                color = colors[i % len(colors)]
                selected_colors.append(color)
            else:
                color = selected_colors[i]
            _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
            _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
            _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]

        image = np.array(video[i_frame])
        image = image * 0.5 + _mask_image * 0.5
        image = image.astype(np.uint8)
        ret_video.append(image)
    return ret_video, selected_colors

def parse_visual_prompts(points):
    ret = {'points': [], 'boxes': []}
    for item in points:
        if item[2] == 1.0:
            ret['points'].append([item[0], item[1]])
        elif item[2] == 2.0 or item[2] == 3.0:
            ret['boxes'].append([item[0], item[1], item[3], item[4]])
        else:
            raise NotImplementedError
    return ret

def get_video_frames(video_path):
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print("Error: Cannot open video file.")
        return []

    frames = []

    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)

    cap.release()
    return frames

def get_frames_from_video(video_path, n_frames=5, sample_type="uniform"):
    frames = get_video_frames(video_path)
    if sample_type == "uniform":
        stride = len(frames) / (n_frames + 1e-4)
        ret = []
        for i in range(n_frames):
            idx = int(i * stride)
            frame = frames[idx]
            frame = frame[:, :, ::-1]  # BGR -> RGB
            frame_image = Image.fromarray(frame).convert('RGB')
            ret.append(frame_image)
    else:
        ret = []
        for frame in frames[:500]:
            frame = frame[:, :, ::-1]  # BGR -> RGB
            frame_image = Image.fromarray(frame).convert('RGB')
            ret.append(frame_image)
    return ret

def preprocess_video(video_path, text):
    # segmentation requests keep the leading frames; everything else samples uniformly
    if "Segment" in text or "segment" in text:
        sample_type = 'begin'
    else:
        sample_type = 'uniform'
    return get_frames_from_video(video_path, sample_type=sample_type)

def image2video_and_save(frames, save_path):
    frames_to_video(frames, save_path)
    return save_path


def frames_to_video(
        frames,
        output_path: str,
        fps: int = 24,
) -> bool:
    try:
        frames = [frame[:, :, ::-1] for frame in frames]  # RGB -> BGR for OpenCV
        # Use the first frame to determine the output size
        height, width = frames[0].shape[:2]

        # Initialize video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # Process each frame
        for frame in frames:
            out.write(frame)

        # Release video writer
        out.release()
        print(f"Video saved successfully to {output_path}")
        return True

    except Exception as e:
        print(f"Error converting frames to video: {str(e)}")
        return False
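
A small smoke test (illustrative, not from the diff) for the overlay helpers above, using synthetic masks in the shape show_mask_pred expects: a list of (k, H, W) arrays whose first slice is used.

import numpy as np
from PIL import Image
from projects.llava_sam2.gradio.app_utils import show_mask_pred, process_markdown

h, w = 120, 160
image = Image.new('RGB', (w, h), (200, 200, 200))

# two fake binary masks, one rectangle each
mask_a = np.zeros((1, h, w), dtype=bool)
mask_a[0, 20:60, 30:90] = True
mask_b = np.zeros((1, h, w), dtype=bool)
mask_b[0, 70:110, 80:150] = True

overlay, colors = show_mask_pred(image, [mask_a, mask_b])
print(overlay.shape, colors)  # (120, 160, 3) plus the two colors that were picked
html = process_markdown(
    'ASSISTANT: <p>object one</p> [SEG] and <p>object two</p> [SEG]', colors)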
    	
projects/llava_sam2/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .llava_sam2 import VideoLLaVASAMModel, VideoLLaVASAMModel_zero3
from .sam2 import SAM2
from .sam2_train import SAM2TrainRunner
    	
projects/llava_sam2/models/extension/__init__.py
ADDED
@@ -0,0 +1 @@
from .sam2_base import SAM2Base
    	
projects/llava_sam2/models/extension/sam2_base.py
ADDED
@@ -0,0 +1,281 @@
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn.functional as F
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
         | 
| 5 | 
            +
            from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class SAM2Base(_SAM2Base):
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                def track_step(
         | 
| 11 | 
            +
                    self,
         | 
| 12 | 
            +
                    frame_idx,
         | 
| 13 | 
            +
                    is_init_cond_frame,
         | 
| 14 | 
            +
                    current_vision_feats,
         | 
| 15 | 
            +
                    current_vision_pos_embeds,
         | 
| 16 | 
            +
                    feat_sizes,
         | 
| 17 | 
            +
                    point_inputs,
         | 
| 18 | 
            +
                    mask_inputs,
         | 
| 19 | 
            +
                    output_dict,
         | 
| 20 | 
            +
                    num_frames,
         | 
| 21 | 
            +
                    track_in_reverse=False,  # tracking in reverse time order (for demo usage)
         | 
| 22 | 
            +
                    # Whether to run the memory encoder on the predicted masks. Sometimes we might want
         | 
| 23 | 
            +
                    # to skip the memory encoder with `run_mem_encoder=False`. For example,
         | 
| 24 | 
            +
                    # in demo we might call `track_step` multiple times for each user click,
         | 
| 25 | 
            +
                    # and only encode the memory when the user finalizes their clicks. And in ablation
         | 
| 26 | 
            +
                    # settings like SAM training on static images, we don't need the memory encoder.
         | 
| 27 | 
            +
                    run_mem_encoder=True,
         | 
| 28 | 
            +
                    # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
         | 
| 29 | 
            +
                    prev_sam_mask_logits=None,
         | 
| 30 | 
            +
                    ## Extension: LLM prompt
         | 
| 31 | 
            +
                    language_embd=None,
         | 
| 32 | 
            +
                ):
         | 
| 33 | 
            +
                    current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
         | 
| 34 | 
            +
                    # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
         | 
| 35 | 
            +
                    if len(current_vision_feats) > 1:
         | 
| 36 | 
            +
                        high_res_features = [
         | 
| 37 | 
            +
                            x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
         | 
| 38 | 
            +
                            for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
         | 
| 39 | 
            +
                        ]
         | 
| 40 | 
            +
                    else:
         | 
| 41 | 
            +
                        high_res_features = None
         | 
| 42 | 
            +
                    if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
         | 
| 43 | 
            +
                        # When use_mask_input_as_output_without_sam=True, we directly output the mask input
         | 
| 44 | 
            +
                        # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
         | 
| 45 | 
            +
                        pix_feat = current_vision_feats[-1].permute(1, 2, 0)
         | 
| 46 | 
            +
                        pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
         | 
| 47 | 
            +
                        sam_outputs = self._use_mask_as_output(
         | 
| 48 | 
            +
                            pix_feat, high_res_features, mask_inputs
         | 
| 49 | 
            +
                        )
         | 
| 50 | 
            +
                    else:
         | 
| 51 | 
            +
                        # fused the visual feature with previous memory features in the memory bank
         | 
| 52 | 
            +
                        pix_feat_with_mem = self._prepare_memory_conditioned_features(
         | 
| 53 | 
            +
                            frame_idx=frame_idx,
         | 
| 54 | 
            +
                            is_init_cond_frame=is_init_cond_frame,
         | 
| 55 | 
            +
                            current_vision_feats=current_vision_feats[-1:],
         | 
| 56 | 
            +
                            current_vision_pos_embeds=current_vision_pos_embeds[-1:],
         | 
| 57 | 
            +
                            feat_sizes=feat_sizes[-1:],
         | 
| 58 | 
            +
                            output_dict=output_dict,
         | 
| 59 | 
            +
                            num_frames=num_frames,
         | 
| 60 | 
            +
                            track_in_reverse=track_in_reverse,
         | 
| 61 | 
            +
                        )
         | 
| 62 | 
            +
                        # apply SAM-style segmentation head
         | 
| 63 | 
            +
                        # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
         | 
| 64 | 
            +
                        # e.g. in demo where such logits come from earlier interaction instead of correction sampling
         | 
| 65 | 
            +
                        # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
         | 
| 66 | 
            +
                        if prev_sam_mask_logits is not None:
         | 
| 67 | 
            +
                            assert point_inputs is not None and mask_inputs is None
         | 
| 68 | 
            +
                            mask_inputs = prev_sam_mask_logits
         | 
| 69 | 
            +
                        multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
         | 
| 70 | 
            +
                        sam_outputs = self._forward_sam_heads(
         | 
| 71 | 
            +
                            backbone_features=pix_feat_with_mem,
         | 
| 72 | 
            +
                            point_inputs=point_inputs,
         | 
| 73 | 
            +
                            mask_inputs=mask_inputs,
         | 
| 74 | 
            +
                            high_res_features=high_res_features,
         | 
| 75 | 
            +
                            multimask_output=multimask_output,
         | 
| 76 | 
            +
                            # Inject language Embed if possible
         | 
| 77 | 
            +
                            language_embd=language_embd,
         | 
| 78 | 
            +
                        )
         | 
| 79 | 
            +
                    (
         | 
| 80 | 
            +
                        _,
         | 
| 81 | 
            +
                        _,
         | 
| 82 | 
            +
                        _,
         | 
| 83 | 
            +
                        low_res_masks,
         | 
| 84 | 
            +
                        high_res_masks,
         | 
| 85 | 
            +
                        obj_ptr,
         | 
| 86 | 
            +
                        _,
         | 
| 87 | 
            +
                    ) = sam_outputs
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    current_out["pred_masks"] = low_res_masks
         | 
| 90 | 
            +
                    current_out["pred_masks_high_res"] = high_res_masks
         | 
| 91 | 
            +
                    current_out["obj_ptr"] = obj_ptr
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    # Finally run the memory encoder on the predicted mask to encode
         | 
| 94 | 
            +
                    # it into a new memory feature (that can be used in future frames)
         | 
| 95 | 
            +
                    if run_mem_encoder and self.num_maskmem > 0:
         | 
| 96 | 
            +
                        high_res_masks_for_mem_enc = high_res_masks
         | 
| 97 | 
            +
                        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
         | 
| 98 | 
            +
                            current_vision_feats=current_vision_feats,
         | 
| 99 | 
            +
                            feat_sizes=feat_sizes,
         | 
| 100 | 
            +
                            pred_masks_high_res=high_res_masks_for_mem_enc,
         | 
| 101 | 
            +
                            is_mask_from_pts=(point_inputs is not None),
         | 
| 102 | 
            +
                        )
         | 
| 103 | 
            +
                        current_out["maskmem_features"] = maskmem_features
         | 
| 104 | 
            +
                        current_out["maskmem_pos_enc"] = maskmem_pos_enc
         | 
| 105 | 
            +
                    else:
         | 
| 106 | 
            +
                        current_out["maskmem_features"] = None
         | 
| 107 | 
            +
                        current_out["maskmem_pos_enc"] = None
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    return current_out
         | 
| 110 | 
            +
             | 
| 111 | 
            +
             | 
| 112 | 
            +
                def _forward_sam_heads(
         | 
| 113 | 
            +
                    self,
         | 
| 114 | 
            +
                    backbone_features,
         | 
| 115 | 
            +
                    point_inputs=None,
         | 
| 116 | 
            +
                    mask_inputs=None,
         | 
| 117 | 
            +
                    high_res_features=None,
         | 
| 118 | 
            +
                    multimask_output=False,
         | 
| 119 | 
            +
                    ## Extension: LLM prompt
         | 
| 120 | 
            +
                    language_embd=None,
         | 
| 121 | 
            +
                ):
         | 
| 122 | 
+        """
+        Forward SAM prompt encoders and mask heads.
+
+        Inputs:
+        - backbone_features: image features of [B, C, H, W] shape
+        - point_inputs: a dictionary with "point_coords" and "point_labels", where
+          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+             absolute pixel-unit coordinates in (x, y) format of the P input points
+          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+             positive clicks, 0 means negative clicks, and -1 means padding
+        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+          same spatial size as the image.
+        - high_res_features: either 1) None or 2) a list of length 2 containing
+          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+          which will be used as high-resolution feature maps for the SAM decoder.
+        - multimask_output: if True, we output 3 candidate masks and their 3
+          corresponding IoU estimates; if False, we output only 1 mask and its
+          corresponding IoU estimate.
+        - language_embd: either None or language prompt embeddings of [B, N, C] shape
+          (the LLM-prompt extension), which are appended to the sparse prompt
+          embeddings before the SAM mask decoder.
+
+        Outputs:
+        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+          output mask logits (before sigmoid) for the low-resolution masks, with 4x
+          the resolution (1/4 stride) of the input backbone_features.
+        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+          if `multimask_output=True` and M = 1 if `multimask_output=False`),
+          upsampled from the low-resolution masks to the same spatial size as the
+          image (stride is 1 pixel).
+        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
+          if `multimask_output=False`), the estimated IoU of each output mask.
+        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `low_res_multimasks`.
+        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `high_res_multimasks`.
+        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+          based on the output token from the SAM mask decoder.
+        """
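
For concreteness, a minimal sketch of the point-prompt dictionary described above (the tensor values are illustrative, not from the repository):

    import torch

    # Two images, up to three clicks each, in absolute (x, y) pixel coordinates.
    point_inputs = {
        # [B=2, P=3, 2], float32
        "point_coords": torch.tensor(
            [[[10.0, 20.0], [55.0, 40.0], [0.0, 0.0]],
             [[100.0, 80.0], [0.0, 0.0], [0.0, 0.0]]]
        ),
        # [B=2, P=3], int32; 1 = positive click, 0 = negative click, -1 = padding
        "point_labels": torch.tensor(
            [[1, 0, -1],
             [1, -1, -1]], dtype=torch.int32
        ),
    }
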
+        B = backbone_features.size(0)
+        device = backbone_features.device
+        assert backbone_features.size(1) == self.sam_prompt_embed_dim
+        assert backbone_features.size(2) == self.sam_image_embedding_size
+        assert backbone_features.size(3) == self.sam_image_embedding_size
+
+        # a) Handle point prompts
+        if point_inputs is not None:
+            sam_point_coords = point_inputs["point_coords"]
+            sam_point_labels = point_inputs["point_labels"]
+            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+        else:
+            # If no points are provided, pad with an empty point (with label -1)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+        # b) Handle mask prompts
+        if mask_inputs is not None:
+            # If mask_inputs is provided, downsize it into low-res mask input if needed
+            # and feed it as a dense mask prompt into the SAM mask encoder
+            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+                sam_mask_prompt = F.interpolate(
+                    mask_inputs.float(),
+                    size=self.sam_prompt_encoder.mask_input_size,
+                    align_corners=False,
+                    mode="bilinear",
+                    antialias=True,  # use antialias for downsampling
+                )
+            else:
+                sam_mask_prompt = mask_inputs
+        else:
+            # Otherwise, simply feed None (and SAM's prompt encoder will add
+            # a learned `no_mask_embed` to indicate no mask input in this case).
+            sam_mask_prompt = None
+
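
The downsizing branch above can be exercised in isolation; a minimal sketch assuming an illustrative 1024-pixel image and a 256x256 prompt-encoder mask input size:

    import torch
    import torch.nn.functional as F

    # Full-resolution boolean mask prompt for a batch of 2 images (sizes illustrative).
    mask_inputs = torch.rand(2, 1, 1024, 1024) > 0.5
    # Downsize with antialiased bilinear interpolation, as in the code above.
    sam_mask_prompt = F.interpolate(
        mask_inputs.float(),
        size=(256, 256),  # stand-in for sam_prompt_encoder.mask_input_size
        mode="bilinear",
        align_corners=False,
        antialias=True,  # reduces aliasing when downsampling
    )
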
+        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+            points=(sam_point_coords, sam_point_labels),
+            boxes=None,
+            masks=sam_mask_prompt,
+        )
+
+        ## Extension: LLM prompt
+        if language_embd is not None:
+            # language_embd has [B, N, C] shape; append it to the sparse prompt
+            # embeddings along the token dimension
+            assert sparse_embeddings.size(0) == language_embd.size(0)
+            assert sparse_embeddings.size(2) == language_embd.size(2)
+            sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
+
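
The extension's shape contract is simple to check standalone; a minimal sketch with made-up sizes (C must equal SAM's prompt embedding dimension, typically 256):

    import torch

    B, n_sparse, n_lang, C = 2, 2, 1, 256  # illustrative sizes
    sparse_embeddings = torch.randn(B, n_sparse, C)  # from the SAM prompt encoder
    language_embd = torch.randn(B, n_lang, C)        # e.g. projected LLM hidden states
    # The decoder then attends to the language tokens exactly like extra point tokens.
    sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
    assert sparse_embeddings.shape == (B, n_sparse + n_lang, C)
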
+        (
+            low_res_multimasks,
+            ious,
+            sam_output_tokens,
+            object_score_logits,
+        ) = self.sam_mask_decoder(
+            image_embeddings=backbone_features,
+            image_pe=self.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=False,  # the image is already batched
+            high_res_features=high_res_features,
+        )
+        if self.pred_obj_scores:
+            is_obj_appearing = object_score_logits > 0
+
+            # In upstream SAM-2, the mask used for spatial memories is a *hard*
+            # choice between obj and no obj, consistent with the actual mask
+            # prediction; this extension leaves that masking disabled:
+            # low_res_multimasks = torch.where(
+            #     is_obj_appearing[:, None, None],
+            #     low_res_multimasks,
+            #     NO_OBJ_SCORE,
+            # )
+
+        # convert masks from possibly bfloat16 (or float16) to float32
+        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+        low_res_multimasks = low_res_multimasks.float()
+        high_res_multimasks = F.interpolate(
+            low_res_multimasks,
+            size=(self.image_size, self.image_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        sam_output_token = sam_output_tokens[:, 0]
+        if multimask_output:
+            # take the best mask prediction (the one with the highest IoU estimate)
+            best_iou_inds = torch.argmax(ious, dim=-1)
+            batch_inds = torch.arange(B, device=device)
+            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            if sam_output_tokens.size(1) > 1:
+                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+        else:
+            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
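
The best-mask selection above is plain advanced indexing; a self-contained sketch with random stand-ins for the decoder outputs:

    import torch

    B, M, H, W = 2, 3, 64, 64
    low_res_multimasks = torch.randn(B, M, H, W)  # stand-in for decoder mask logits
    ious = torch.rand(B, M)                       # stand-in for predicted IoUs

    best_iou_inds = torch.argmax(ious, dim=-1)    # [B], best candidate per image
    batch_inds = torch.arange(B)
    # Pick one mask per image and keep a singleton channel dim: [B, 1, H, W]
    low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
    assert low_res_masks.shape == (B, 1, H, W)
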
+        # Extract object pointer from the SAM output token (with occlusion handling)
+        obj_ptr = self.obj_ptr_proj(sam_output_token)
+        if self.pred_obj_scores:
+            # Allow a *soft* no-object pointer, unlike for masks
+            if self.soft_no_obj_ptr:
+                # a hard choice is only possible with ground-truth object scores
+                assert not self.teacher_force_obj_scores_for_mem
+                lambda_is_obj_appearing = object_score_logits.sigmoid()
+            else:
+                lambda_is_obj_appearing = is_obj_appearing.float()
+
+            if self.fixed_no_obj_ptr:
+                obj_ptr = lambda_is_obj_appearing * obj_ptr
+            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
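
The occlusion handling above blends the pointer with a learned no-object pointer; when `fixed_no_obj_ptr` is set, this amounts to the convex combination lambda * obj_ptr + (1 - lambda) * no_obj_ptr. A standalone numeric sketch (all tensors are stand-ins):

    import torch

    C = 256
    obj_ptr = torch.randn(1, C)                # projected SAM output token
    no_obj_ptr = torch.randn(C)                # stand-in for the learned parameter
    object_score_logits = torch.tensor([[2.0]])

    lam = object_score_logits.sigmoid()        # soft object-appearance weight in (0, 1)
    obj_ptr = lam * obj_ptr                    # the `fixed_no_obj_ptr` branch
    obj_ptr = obj_ptr + (1 - lam) * no_obj_ptr
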
+        return (
+            low_res_multimasks,
+            high_res_multimasks,
+            ious,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            object_score_logits,
+        )
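
Putting it together, a hypothetical call site (assuming `model` is an instance of this extended SAM-2 base class and the tensors follow the shapes documented in the docstring):

    (
        low_res_multimasks,   # [B, M, H*4, W*4] mask logits
        high_res_multimasks,  # [B, M, H*16, W*16] upsampled logits
        ious,                 # [B, M] predicted IoUs
        low_res_masks,        # [B, 1, H*4, W*4] best mask
        high_res_masks,       # [B, 1, H*16, W*16] best mask
        obj_ptr,              # [B, C] object pointer
        object_score_logits,  # objectness logits
    ) = model._forward_sam_heads(
        backbone_features=backbone_features,  # [B, C, H, W]
        point_inputs=None,                    # an empty padded point is used internally
        mask_inputs=None,
        high_res_features=high_res_features,
        multimask_output=True,
        language_embd=language_embd,          # [B, N, C] LLM prompt tokens (the extension)
    )
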