Spaces:
Running
on
Zero
Running
on
Zero
Upload 54 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +11 -0
- LICENSE +201 -0
- README.md +13 -12
- app.py +435 -0
- assets/teaser.png +3 -0
- examples/amber.png +3 -0
- examples/armour.png +3 -0
- examples/art.wav +3 -0
- examples/chris.png +3 -0
- examples/dream.mp3 +3 -0
- examples/fictional.wav +3 -0
- examples/fight.wav +3 -0
- examples/jacket.png +3 -0
- examples/naomi.png +3 -0
- examples/science.wav +0 -0
- examples/vangogh.jpg +3 -0
- humo/common/__init__.py +0 -0
- humo/common/config.py +107 -0
- humo/common/distributed/__init__.py +41 -0
- humo/common/distributed/advanced.py +484 -0
- humo/common/distributed/basic.py +143 -0
- humo/common/logger.py +44 -0
- humo/configs/inference/generate.yaml +78 -0
- humo/configs/inference/generate_1_7B.yaml +76 -0
- humo/configs/models/Wan_1.3B.yaml +17 -0
- humo/configs/models/Wan_1.3B_I2V.yaml +18 -0
- humo/configs/models/Wan_14B.yaml +17 -0
- humo/configs/models/Wan_14B_I2V.yaml +18 -0
- humo/generate.py +984 -0
- humo/generate_1_7B.py +622 -0
- humo/models/audio/audio_proj.py +87 -0
- humo/models/distributed/__init__.py +0 -0
- humo/models/distributed/dit_ulysses_sequence_parallel.py +270 -0
- humo/models/distributed/fsdp.py +42 -0
- humo/models/text/encoder.py +173 -0
- humo/models/utils/fm_solvers.py +857 -0
- humo/models/utils/fm_solvers_unipc.py +800 -0
- humo/models/utils/utils.py +58 -0
- humo/models/wan_modules/__init__.py +16 -0
- humo/models/wan_modules/attention.py +256 -0
- humo/models/wan_modules/clip.py +542 -0
- humo/models/wan_modules/model.py +619 -0
- humo/models/wan_modules/model_humo.py +803 -0
- humo/models/wan_modules/t5.py +525 -0
- humo/models/wan_modules/tokenizers.py +82 -0
- humo/models/wan_modules/vae.py +666 -0
- humo/models/wan_modules/xlm_roberta.py +170 -0
- humo/utils/audio_processor_whisper.py +173 -0
- humo/utils/wav2vec.py +218 -0
- main.py +28 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
examples/amber.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
examples/armour.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
examples/art.wav filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
examples/chris.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
examples/dream.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
examples/fictional.wav filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
examples/fight.wav filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
examples/jacket.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
examples/naomi.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
examples/vangogh.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2025 Bytedance
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: HuMo Local
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: HuMo [Local]
|
| 3 |
+
emoji: 👩🦱
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: gray
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.47.2
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
short_description: Reference based video generation
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import uuid
|
| 7 |
+
import shutil
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from huggingface_hub import snapshot_download, list_repo_files, hf_hub_download
|
| 12 |
+
import importlib, site
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Re-discover all .pth/.egg-link files
|
| 16 |
+
for sitedir in site.getsitepackages():
|
| 17 |
+
site.addsitedir(sitedir)
|
| 18 |
+
|
| 19 |
+
# Clear caches so importlib will pick up new modules
|
| 20 |
+
importlib.invalidate_caches()
|
| 21 |
+
|
| 22 |
+
def sh(cmd):
    """Run ``cmd`` through the system shell.

    Args:
        cmd: Command line handed to the shell as a single string.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    subprocess.check_call(cmd, shell=True)
|
| 23 |
+
|
| 24 |
+
flash_attention_installed = False
|
| 25 |
+
|
| 26 |
+
import shlex

# Best-effort install of a prebuilt FlashAttention wheel fetched from the Hub.
# Any failure is reported and the app carries on without it.
try:
    # Announce the attempt before doing the work so the log reads in order.
    print("Attempting to download and install FlashAttention wheel...")
    flash_attention_wheel = hf_hub_download(
        repo_id="alexnasa/flash-attn-3",
        repo_type="model",
        filename="128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
    )

    # `sh` runs the string through a shell, so quote the downloaded path in
    # case the cache directory contains spaces or shell metacharacters.
    sh(f"pip install {shlex.quote(flash_attention_wheel)}")
    # sh("pip install flash-attn")

    # The FlashAttention wheel is installed at this point. Record it now so an
    # unrelated failure further down cannot flip the flag back to False.
    flash_attention_installed = True

except Exception as e:
    print(f"⚠️ Could not install FlashAttention: {e}")
    print("Continuing without FlashAttention...")

else:
    # Still only attempted after a successful FlashAttention install, as before,
    # but handled separately: a failure here is not a FlashAttention failure.
    try:
        sh("pip install --no-build-isolation transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl")
    except Exception as e:
        print(f"⚠️ Could not install local Transformer Engine wheel: {e}")

    # tell Python to re-scan site-packages now that new packages were installed
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()
|
| 46 |
+
|
| 47 |
+
import shlex

# Best-effort install of a prebuilt Transformer Engine wheel fetched from the Hub.
# Any failure is reported and the app carries on without it.
try:
    # Announce the attempt before doing the work so the log reads in order.
    print("Attempting to download and install Transformer Engine wheel...")
    te_wheel = hf_hub_download(
        repo_id="alexnasa/transformer_engine_wheels",
        repo_type="model",
        filename="transformer_engine-2.5.0+f05f12c9-cp310-cp310-linux_x86_64.whl",
    )

    # `sh` runs the string through a shell, so quote the downloaded path in
    # case the cache directory contains spaces or shell metacharacters.
    sh(f"pip install {shlex.quote(te_wheel)}")

    # tell Python to re-scan site-packages now that a new package was installed
    site.addsitedir(site.getsitepackages()[0])
    importlib.invalidate_caches()

except Exception as e:
    print(f"⚠️ Could not install Transformer Engine : {e}")
    print("Continuing without Transformer Engine ...")
|
| 63 |
+
|
| 64 |
+
import torch
|
| 65 |
+
print(f"Torch version: {torch.__version__}")
|
| 66 |
+
print(f"FlashAttention available: {flash_attention_installed}")
|
| 67 |
+
|
| 68 |
+
import tempfile
|
| 69 |
+
from pathlib import Path
|
| 70 |
+
from torch._inductor.runtime.runtime_utils import cache_dir as _inductor_cache_dir
|
| 71 |
+
from huggingface_hub import HfApi
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# Fetch the model repositories into ./weights (runs once, at import time).
snapshot_download(
    repo_id="bytedance-research/HuMo",
    local_dir="./weights/HuMo",
)
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="./weights/Wan2.1-T2V-1.3B",
)
snapshot_download(
    repo_id="openai/whisper-large-v3",
    local_dir="./weights/whisper-large-v3",
)

# Root directory under which each session's generated files are stored.
os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"

# Make ./humo importable so `common.config` below resolves as a top-level package.
path_to_insert = "humo"
if path_to_insert not in sys.path:
    sys.path.insert(0, path_to_insert)

from common.config import load_config, create_object

# Load the inference config with the overrides this demo runs with.
config = load_config(
    "./humo/configs/inference/generate.yaml",
    [
        "dit.sp_size=1",
        "generation.frames=97",
        "generation.scale_t=5.5",
        "generation.scale_a=5.0",
        "generation.mode=TIA",
        "generation.height=480",
        "generation.width=832",
    ],
)
runner = create_object(config)


# Only applied when the variable is not already set in the environment.
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", f"{os.getcwd()}/torchinductor_space")  # or another writable path
|
| 102 |
+
|
| 103 |
+
def restore_inductor_cache_from_hub(repo_id: str, filename: str = "torch_compile_cache.zip",
                                    path_in_repo: str = "inductor_cache", repo_type: str = "model",
                                    hf_token: str | None = None):
    """Download a zipped TorchInductor cache from the Hub and unpack it locally.

    Args:
        repo_id: Hub repository that holds the archive.
        filename: Archive file name inside ``path_in_repo``.
        path_in_repo: Folder within the repository that contains the archive.
        repo_type: Hub repository type passed to ``hf_hub_download``.
        hf_token: Optional access token passed to ``hf_hub_download``.
    """
    # Resolve and create the local cache directory before extracting into it.
    cache_root = Path(_inductor_cache_dir()).resolve()
    cache_root.mkdir(parents=True, exist_ok=True)

    downloaded_archive = hf_hub_download(
        repo_id=repo_id,
        filename=f"{path_in_repo}/{filename}",
        repo_type=repo_type,
        token=hf_token,
    )
    shutil.unpack_archive(downloaded_archive, extract_dir=str(cache_root))
    print(f"✓ Restored cache into {cache_root}")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# restore_inductor_cache_from_hub("alexnasa/humo-compiled")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_duration(prompt_text, steps, image_file, audio_file_path, tea_cache_l1_thresh, max_duration, session_id):
    """Return the estimated run time in seconds for one generation request.

    Passed as ``duration=`` to the ``spaces.GPU`` decorator on ``run_pipeline``,
    so it accepts the same arguments. Only ``steps`` and ``max_duration``
    influence the estimate; the remaining arguments are ignored.
    """
    return calculate_required_time(steps=steps, max_duration=max_duration)
|
| 120 |
+
|
| 121 |
+
def calculate_required_time(steps, max_duration):
    """Estimate how many seconds a generation run will need.

    The estimate is ``per_step_seconds * steps`` plus a fixed 60 second
    warm-up, where ``per_step_seconds`` is looked up from ``max_duration``.

    Args:
        steps: Number of diffusion steps.
        max_duration: Duration setting; must be one of 1, 2, 3, 4 or 5.

    Returns:
        The estimate truncated to an ``int``.

    Raises:
        KeyError: If ``max_duration`` is not one of the supported settings.
    """
    startup_overhead_s = 60

    # Seconds spent on each diffusion step for every supported duration setting.
    per_step_seconds = {1: 8, 2: 8, 3: 11, 4: 20, 5: 30}[max_duration]

    duration_s = (per_step_seconds * steps) + startup_overhead_s
    print(f'estimated duration:{duration_s}')
    return int(duration_s)
|
| 138 |
+
|
| 139 |
+
def get_required_time_string(steps, max_duration):
    """Build the centered HTML line that shows the estimated ZeroGPU time.

    Args:
        steps: Number of diffusion steps.
        max_duration: Duration setting forwarded to ``calculate_required_time``.

    Returns:
        An HTML string with the estimate in seconds and in minutes (one decimal).
    """
    duration_s = calculate_required_time(steps=steps, max_duration=max_duration)
    # The estimate is an int; the minutes figure is derived for display only.
    duration_m = duration_s / 60
    return f"<center>⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)</center>"
|
| 145 |
+
|
| 146 |
+
def update_required_time(steps, max_duration):
    """Return the refreshed time-estimate HTML for the given settings.

    Thin wrapper around ``get_required_time_string``.
    """
    return get_required_time_string(steps=steps, max_duration=max_duration)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def generate_scene(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh, max_duration = 2, session_id = None):
    """Validate the request and hand it to ``run_pipeline``.

    The checks run here, before the GPU-decorated ``run_pipeline`` is called,
    so an invalid request is rejected without entering it.

    Raises:
        gr.Error: If the prompt is empty/blank, or if neither a reference
            image nor an audio file was supplied.

    Returns:
        Whatever ``run_pipeline`` returns for the same arguments.
    """
    print(image_paths)

    # Reject a missing or whitespace-only prompt.
    if not (prompt_text or "").strip():
        raise gr.Error("Please enter a prompt.")

    # At least one conditioning input (image or audio) is required.
    if not (audio_file_path or image_paths):
        raise gr.Error("Please provide a reference image or a lipsync audio.")

    return run_pipeline(
        prompt_text,
        steps,
        image_paths,
        audio_file_path,
        tea_cache_l1_thresh,
        max_duration,
        session_id,
    )
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def upload_inductor_cache_to_hub(
    repo_id: str,
    path_in_repo: str = "inductor_cache",
    repo_type: str = "model",  # or "dataset" if you prefer
    hf_token: str | None = None,
):
    """
    Zip the current TorchInductor cache and upload it to a Hub repository.

    The cache directory must already exist, i.e. a compiled model has to have
    been run at least once. A small metadata text file is uploaded next to the
    archive.

    Args:
        repo_id: Target Hub repository; created if it does not exist.
        path_in_repo: Folder inside the repository that receives the files.
        repo_type: Hub repository type.
        hf_token: Optional access token used for the Hub client.

    Raises:
        FileNotFoundError: If the TorchInductor cache directory does not exist.
    """
    cache_dir = Path(_inductor_cache_dir()).resolve()
    if not cache_dir.exists():
        raise FileNotFoundError(f"TorchInductor cache not found at {cache_dir}. "
                                "Run a compiled model once to populate it.")

    # The archive lives in a temporary directory, so every upload has to
    # finish before the ``with`` block exits and removes it.
    with tempfile.TemporaryDirectory() as scratch_dir:
        zip_stem = Path(scratch_dir) / "torch_compile_cache"
        archive_path = Path(shutil.make_archive(str(zip_stem), "zip", root_dir=str(cache_dir)))

        hub = HfApi(token=hf_token)
        hub.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)

        # Upload the zip under path_in_repo.
        hub.upload_file(
            path_or_fileobj=str(archive_path),
            path_in_repo=f"{path_in_repo}/{archive_path.name}",
            repo_id=repo_id,
            repo_type=repo_type,
        )

        # Upload a small metadata stamp alongside it (optional but handy for traceability).
        meta_txt = (
            f"pytorch={torch.__version__}\n"
            f"inductor_cache_dir={cache_dir}\n"
            f"cuda_available={torch.cuda.is_available()}\n"
            f"cuda_device={torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'}\n"
        )
        hub.upload_file(
            path_or_fileobj=meta_txt.encode(),
            path_in_repo=f"{path_in_repo}/INDUCTOR_CACHE_METADATA.txt",
            repo_id=repo_id,
            repo_type=repo_type,
        )

    print("✔ Uploaded TorchInductor cache to the Hub.")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@spaces.GPU(duration=get_duration)
def run_pipeline(prompt_text, steps, image_paths, audio_file_path, tea_cache_l1_thresh = 0.0, max_duration = 2, session_id = None):
    """Run one generation and return the path of the produced video.

    Args:
        prompt_text: Text prompt; must not be empty or blank.
        steps: Step count forwarded to ``runner.inference_loop``.
        image_paths: Reference images as a sequence whose items hold the file
            path at index 0, or a falsy value when none are given.
        audio_file_path: Audio file path (or an object with a ``name``
            attribute), or a falsy value when none is given.
        tea_cache_l1_thresh: Forwarded unchanged to ``runner.inference_loop``.
        max_duration: Duration setting (1-5) that selects the frame count.
        session_id: Sub-directory name for outputs; a random one is generated
            when ``None``.

    Returns:
        Path of the generated ``.mp4``. If the expected file is missing, the
        most recently modified ``.mp4`` in the output directory is returned,
        or ``None`` when there is none.

    Raises:
        gr.Error: If the prompt is empty, or neither image nor audio is given.
        KeyError: If ``max_duration`` is not one of 1-5.
    """
    session_id = uuid.uuid4().hex if session_id is None else session_id

    # Validate inputs
    prompt_text = (prompt_text or "").strip()
    if not prompt_text:
        raise gr.Error("Please enter a prompt.")
    if not (audio_file_path or image_paths):
        raise gr.Error("Please provide a reference image or a lipsync audio.")

    # Pick the mode from the inputs that are present: "TA" when there are no
    # images, "TI" when there is no audio, "TIA" when both are supplied.
    # NOTE(review): presumably T/I/A stand for text/image/audio — confirm.
    if not image_paths:
        inference_mode = "TA"
    elif not audio_file_path:
        inference_mode = "TI"
    else:
        inference_mode = "TIA"

    # Normalise the audio input to a plain path string (or None).
    if not audio_file_path:
        audio_path = None
    elif isinstance(audio_file_path, str):
        audio_path = audio_file_path
    else:
        audio_path = getattr(audio_file_path, "name", str(audio_file_path))

    # Keep only the first element of each image entry.
    img_paths = [entry[0] for entry in image_paths] if image_paths else None

    # Prepare output
    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    # Random filename
    filename = f"gen_{uuid.uuid4().hex[:10]}"
    width, height = 832, 480

    # Frame count used for each duration setting.
    frames_for_duration = {1: 25, 2: 45, 3: 70, 4: 97, 5: 129}

    # Run inference
    runner.inference_loop(
        prompt_text,
        img_paths,
        audio_path,
        output_dir,
        filename,
        inference_mode,
        width,
        height,
        steps,
        frames=int(frames_for_duration[max_duration]),
        tea_cache_l1_thresh=tea_cache_l1_thresh,
    )

    # Return resulting video path
    video_path = os.path.join(output_dir, f"{filename}.mp4")
    if os.path.exists(video_path):
        # upload_inductor_cache_to_hub("alexnasa/humo-compiled")
        return video_path

    # Expected file is missing: fall back to the newest .mp4 in the output directory.
    candidates = [
        os.path.join(output_dir, entry)
        for entry in os.listdir(output_dir)
        if entry.endswith(".mp4")
    ]
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
|
| 288 |
+
|
| 289 |
+
css = """
|
| 290 |
+
#col-container {
|
| 291 |
+
margin: 0 auto;
|
| 292 |
+
width: 100%;
|
| 293 |
+
max-width: 720px;
|
| 294 |
+
}
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
def cleanup(request: gr.Request):
    """
    Remove the results directory that belongs to the requesting session.

    The directory is ``$PROCESSED_RESULTS/<session_hash>``. Nothing happens
    when the request carries no session hash, and removal errors are ignored.
    """
    session_id = request.session_hash
    if not session_id:
        return
    session_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    shutil.rmtree(session_dir, ignore_errors=True)
|
| 303 |
+
|
| 304 |
+
def start_session(request: gr.Request):
    """
    Return the session hash of the incoming request.
    """
    session_hash = request.session_hash
    return session_hash
|
| 307 |
+
|
| 308 |
+
with gr.Blocks(css=css) as demo:

    # Session identifier: filled in on page load and forwarded to the
    # generation callback so each session writes to its own output folder.
    session_state = gr.State()
    demo.load(start_session, outputs=[session_state])

    with gr.Sidebar(width=400):

        gr.HTML(
            """
            <div style="text-align: center;">
                <p style="font-size:16px; display: inline; margin: 0;">
                    <strong>HuMo</strong> – Human-Centric Video Generation via Collaborative Multi-Modal Conditioning
                </p>
                <a href="https://github.com/Phantom-video/HuMo" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    [Github]
                </a>
            </div>
            """
        )

        gr.Markdown("**REFERENCE IMAGES**")

        img_input = gr.Gallery(
            show_label=False,
            label="",
            interactive=True,
            rows=1, columns=3, object_fit="contain", height="280",
            file_types=['image']
        )

        gr.Markdown("**LIPSYNC AUDIO**")

        audio_input = gr.Audio(
            sources=["upload"],
            show_label=False,
            type="filepath",
        )

        gr.Markdown("**SETTINGS**")

        # Shared by the sliders and by the initial "time required" text.
        default_steps = 10
        default_max_duration = 2

        max_duration = gr.Slider(minimum=2, maximum=5, value=default_max_duration, step=1, label="Max Duration")
        steps_input = gr.Slider(minimum=5, maximum=50, value=default_steps, step=5, label="Diffusion Steps")
        # Hidden control; its value is still passed to the generation callback.
        tea_cache_l1_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Cache", visible=False)

    with gr.Column(elem_id="col-container"):

        gr.HTML(
            """
            <div style="text-align: center;">
                <strong>HF Space by:</strong>
                <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
                </a>
            </div>
            """
        )

        video_output = gr.Video(show_label=False)

        gr.Markdown("<center><h2>PROMPT</h2></center>")

        prompt_tb = gr.Textbox(
            show_label=False,
            lines=5,
            placeholder="Describe the scene and the person talking....",
        )

        gr.Markdown("")
        time_required = gr.Markdown(get_required_time_string(default_steps, default_max_duration))
        run_btn = gr.Button("🎬 Action", variant="primary")

        # Each example row is [prompt, diffusion steps, reference images, audio],
        # matching the order of `inputs` below.
        gr.Examples(
            examples=[

                [
                    "A handheld tracking shot follows a female warrior walking through a cave. Her determined eyes are locked straight ahead. She speaks with intensity.",
                    5,
                    ["./examples/naomi.png"],
                    "./examples/dream.mp3",
                ],

                [
                    "A reddish-brown haired and bearded man sits pensively against swirling blue-and-white brushstrokes, dressed in a blue coat and dark waistcoat. The artistic backdrop and his thoughtful pose evoke a Post-Impressionist style in a studio-like setting.",
                    10,
                    ["./examples/vangogh.jpg"],
                    "./examples/art.wav",
                ],

                [
                    "A handheld tracking shot follows a female through a science lab. Her determined eyes are locked straight ahead. The clip is in black and white and patchy as she is explaining something to someone standing opposite her",
                    10,
                    ["./examples/naomi.png"],
                    "./examples/science.wav",
                ],

                [
                    "A woman with long, wavy dark hair looking at a person sitting opposite her whilst holding a book, wearing a leather jacket, long-sleeved jacket with a semi purple color one seen on a photo. Warm, window-like light bathes her figure, highlighting the outfit's elegant design and her graceful movements.",
                    50,
                    ["./examples/amber.png", "./examples/jacket.png"],
                    # This clip is shipped as a .wav file in examples/.
                    "./examples/fictional.wav",
                ],

            ],
            inputs=[prompt_tb, steps_input, img_input, audio_input],
            outputs=[video_output],
            fn=run_pipeline,
            cache_examples=True,
        )

    # Keep the "time required" text in sync with both sliders.
    max_duration.change(update_required_time, [steps_input, max_duration], time_required)
    steps_input.change(update_required_time, [steps_input, max_duration], time_required)

    run_btn.click(
        fn=generate_scene,
        inputs=[prompt_tb, steps_input, img_input, audio_input, tea_cache_l1_thresh, max_duration, session_state],
        outputs=[video_output],
    )
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
if __name__ == "__main__":
    # Register the per-session cleanup callback for session unload.
    demo.unload(cleanup)
    # Enable the request queue, then start serving.
    demo.queue()
    demo.launch(ssr_mode=False)
|
assets/teaser.png
ADDED
|
Git LFS Details
|
examples/amber.png
ADDED
|
Git LFS Details
|
examples/armour.png
ADDED
|
Git LFS Details
|
examples/art.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:72c75df8e93a107e262ea9b002a66e72d3c1cd2084bce1474a31d8afffd0b651
|
| 3 |
+
size 114254
|
examples/chris.png
ADDED
|
Git LFS Details
|
examples/dream.mp3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:27248fd9e8f29bd60ccb1163b8df3c6f2630734f358aa3362ffe67e8148e0eb1
|
| 3 |
+
size 108275
|
examples/fictional.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31b550e6433ea44a0642dee90c326664ff4f568fec184170001f834597b3ad23
|
| 3 |
+
size 167084
|
examples/fight.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8dbee86c85e992ac6d17820a3730bf753fc9bf5bac6b8a470f84b7e98a64221a
|
| 3 |
+
size 264782
|
examples/jacket.png
ADDED
|
Git LFS Details
|
examples/naomi.png
ADDED
|
Git LFS Details
|
examples/science.wav
ADDED
|
Binary file (82.5 kB). View file
|
|
|
examples/vangogh.jpg
ADDED
|
Git LFS Details
|
humo/common/__init__.py
ADDED
|
File without changes
|
humo/common/config.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/config.py
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Configuration utility functions
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import importlib
|
| 20 |
+
from typing import Any, Callable, List, Union
|
| 21 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 22 |
+
|
| 23 |
+
OmegaConf.register_new_resolver("eval", eval)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration file and resolve ``__inherit__`` entries.

    When ``argv`` is given it is parsed as an OmegaConf dot-list and merged
    over the loaded file before inheritance is resolved.
    """
    loaded = OmegaConf.load(path)
    if argv is not None:
        overrides = OmegaConf.from_dotlist(argv)
        loaded = OmegaConf.merge(loaded, overrides)
    return resolve_recursive(loaded, resolve_inheritance)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    """
    Apply ``resolver`` to ``config`` and then to every nested container.

    The resolver runs on the node first; the children of its result are then
    visited, and each DictConfig/ListConfig child is replaced in place by its
    resolved version.
    """
    resolved = resolver(config)
    if isinstance(resolved, DictConfig):
        indices = resolved.keys()
    elif isinstance(resolved, ListConfig):
        indices = range(len(resolved))
    else:
        # Leaf value: nothing to descend into.
        return resolved
    for index in indices:
        child = resolved.get(index)
        if isinstance(child, (DictConfig, ListConfig)):
            resolved[index] = resolve_recursive(child, resolver)
    return resolved
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Resolve a single ``__inherit__: path/to/parent.yaml`` entry.

    The key is popped from ``config``; the parent file is loaded with
    ``load_config`` and the remaining keys of ``config`` are merged over it.
    Configs that are not a DictConfig, or have no ``__inherit__`` value, are
    returned as they are.
    """
    if not isinstance(config, DictConfig):
        return config
    parent_path = config.pop("__inherit__", None)
    if not parent_path:
        return config
    assert isinstance(parent_path, str)
    parent = load_config(parent_path)
    if len(config.keys()) > 0:
        return OmegaConf.merge(parent, config)
    # Nothing left to override: the parent is the result.
    return parent
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def import_item(path: str, name: str) -> Any:
    """
    Import module ``path`` and return its attribute ``name``.
    Example: import_item("path.to.file", "MyClass") -> MyClass
    """
    module = importlib.import_module(path)
    return getattr(module, name)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def create_object(config: DictConfig) -> Any:
    """
    Instantiate the object described by ``config``.

    The config is expected to contain:
    __object__:
        path: path.to.module
        name: MyClass
        args: as_config | as_params (defaults to as_config)

    With ``as_config`` the item is called with the config itself; with
    ``as_params`` the config is converted to a plain object, ``__object__`` is
    removed, and the rest is passed as keyword arguments.

    Raises NotImplementedError for any other ``args`` value.
    """
    spec = config.__object__
    factory = import_item(path=spec.path, name=spec.name)
    mode = spec.get("args", "as_config")
    if mode == "as_config":
        return factory(config)
    if mode != "as_params":
        raise NotImplementedError(f"Unknown args type: {mode}")
    kwargs = OmegaConf.to_object(config)
    kwargs.pop("__object__")
    return factory(**kwargs)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def create_dataset(path: str, *args, **kwargs) -> Any:
    """
    Call the ``create_dataset`` function of module ``path`` with the given
    arguments and return its result.
    """
    factory = import_item(path, "create_dataset")
    return factory(*args, **kwargs)
|
humo/common/distributed/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Distributed package.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from .basic import (
|
| 20 |
+
barrier_if_distributed,
|
| 21 |
+
convert_to_ddp,
|
| 22 |
+
get_device,
|
| 23 |
+
get_global_rank,
|
| 24 |
+
get_local_rank,
|
| 25 |
+
get_world_size,
|
| 26 |
+
init_torch,
|
| 27 |
+
meta_param_init_fn,
|
| 28 |
+
meta_non_persistent_buffer_init_fn
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
__all__ = [
|
| 32 |
+
"barrier_if_distributed",
|
| 33 |
+
"convert_to_ddp",
|
| 34 |
+
"get_device",
|
| 35 |
+
"get_global_rank",
|
| 36 |
+
"get_local_rank",
|
| 37 |
+
"get_world_size",
|
| 38 |
+
"init_torch",
|
| 39 |
+
"meta_param_init_fn",
|
| 40 |
+
"meta_non_persistent_buffer_init_fn",
|
| 41 |
+
]
|
humo/common/distributed/advanced.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Advanced distributed functions for sequence parallel.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from typing import Any, List, Optional, Tuple, Union
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
|
| 24 |
+
from .basic import get_global_rank, get_world_size
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
_DATA_PARALLEL_GROUP = None
|
| 28 |
+
_SEQUENCE_PARALLEL_GROUP = None
|
| 29 |
+
_SEQUENCE_PARALLEL_CPU_GROUP = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
_CFG_PARALLEL_GROUP = None
|
| 33 |
+
_CFG_PARALLEL_CPU_GROUP = None
|
| 34 |
+
|
| 35 |
+
def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the module-level data parallel process group (may be None).
    """
    dp_group = _DATA_PARALLEL_GROUP
    return dp_group
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Return the module-level sequence parallel process group (may be None).
    """
    sp_group = _SEQUENCE_PARALLEL_GROUP
    return sp_group
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Return the module-level sequence parallel CPU process group (may be None).
    """
    sp_cpu_group = _SEQUENCE_PARALLEL_CPU_GROUP
    return sp_cpu_group
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_data_parallel_rank() -> int:
    """
    Return this process's rank in the data parallel group, falling back to the
    global rank when no data parallel group is set.
    """
    dp_group = get_data_parallel_group()
    if dp_group:
        return dist.get_rank(dp_group)
    return get_global_rank()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_data_parallel_world_size() -> int:
    """
    Return the size of the data parallel group, falling back to the global
    world size when no data parallel group is set.
    """
    dp_group = get_data_parallel_group()
    if dp_group:
        return dist.get_world_size(dp_group)
    return get_world_size()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_sequence_parallel_rank() -> int:
    """
    Return this process's rank in the sequence parallel group, or 0 when no
    sequence parallel group is set.
    """
    sp_group = get_sequence_parallel_group()
    if sp_group:
        return dist.get_rank(sp_group)
    return 0
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_sequence_parallel_world_size() -> int:
    """
    Return the size of the sequence parallel group, or 1 when no sequence
    parallel group is set.
    """
    sp_group = get_sequence_parallel_group()
    if sp_group:
        return dist.get_world_size(sp_group)
    return 1
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def init_unified_parallel(unified_parallel_size):
    """
    Partition all ranks into consecutive groups of ``unified_parallel_size``
    and remember the group this rank belongs to.

    For every partition two process groups are created: one with the default
    backend and one with the "gloo" backend. The pair containing the current
    rank is stored in the module-level sequence parallel slots. A size of 1 is
    a no-op.
    """
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_CPU_GROUP

    if unified_parallel_size == 1:
        return

    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % unified_parallel_size == 0
    num_partitions = world_size // unified_parallel_size

    for partition in range(num_partitions):
        first_rank = partition * unified_parallel_size
        member_ranks = range(first_rank, first_rank + unified_parallel_size)
        # Every rank runs both new_group calls for every partition, in the
        # same order, whether or not it is a member.
        device_group = dist.new_group(member_ranks)
        cpu_group = dist.new_group(member_ranks, backend="gloo")
        if rank in member_ranks:
            _SEQUENCE_PARALLEL_GROUP = device_group
            _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def get_unified_parallel_group():
    """
    Return the unified parallel group, which is kept in the module-level
    sequence parallel group slot (may be None).
    """
    return _SEQUENCE_PARALLEL_GROUP
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_unified_parallel_cpu_group():
    """
    Return the unified parallel CPU group, which is kept in the module-level
    sequence parallel CPU group slot (may be None).
    """
    return _SEQUENCE_PARALLEL_CPU_GROUP
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_unified_parallel_rank():
    """
    Return this process's rank in the unified parallel group, or 0 when no
    group is set.
    """
    up_group = get_unified_parallel_group()
    if up_group:
        return dist.get_rank(up_group)
    return 0
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_unified_parallel_world_size():
    """
    Return the size of the unified parallel group, or 1 when no group is set.
    """
    up_group = get_unified_parallel_group()
    if up_group:
        return dist.get_world_size(up_group)
    return 1
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def is_unified_parallel_initialized():
    """
    Return True when a unified parallel group has been set.
    """
    return get_unified_parallel_group() is not None
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def pad_tensor(x: Tensor, dim: int, padding_size: int):
    """
    Append ``padding_size`` zeros to ``x`` along dimension ``dim``.

    The padding has the dtype and device of ``x``; all other dimensions keep
    their size.
    """
    pad_shape = list(x.shape)
    pad_shape[dim] = padding_size
    zeros = x.new_zeros(pad_shape)
    return torch.cat([x, zeros], dim=dim)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Slice(torch.autograd.Function):
    """
    Autograd function that keeps only the calling rank's chunk of a tensor.

    Forward splits ``local_input`` along ``dim`` into chunks of size
    ``shape[dim] // world_size`` and returns the chunk indexed by this rank in
    ``group``. Backward all-gathers the chunk gradients of every rank and
    concatenates them along ``dim``, optionally dividing by the world size.
    With a falsy ``group`` both directions are the identity.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor:
        ctx.group = group
        ctx.dim = dim
        ctx.scale_grad = scale_grad
        # Check for a missing group before querying torch.distributed, so the
        # pass-through path works even when no process group is initialised.
        if not group:
            return local_input
        ctx.rank = dist.get_rank(group)
        ctx.seq_world_size = dist.get_world_size(group)
        chunk_size = local_input.shape[dim] // ctx.seq_world_size
        return local_input.split(chunk_size, dim=dim)[ctx.rank].contiguous()

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # One gradient slot per forward input: (group, local_input, dim, scale_grad).
        if not ctx.group:
            return None, grad_output, None, None
        gathered_shape = list(grad_output.size())
        split_size = gathered_shape[0]
        gathered_shape[0] = split_size * ctx.seq_world_size
        # Allocate the gather buffer next to the gradient rather than on
        # whichever CUDA device happens to be current.
        output = torch.empty(gathered_shape, dtype=grad_output.dtype, device=grad_output.device)
        dist.all_gather_into_tensor(output, grad_output.contiguous(), group=ctx.group)
        if ctx.scale_grad:
            output = output / ctx.seq_world_size
        # The gather stacks the per-rank gradients along dim 0; re-split them
        # and lay them out along the dimension that was sliced in forward.
        return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def gather_outputs(
    x: Tensor,
    gather_dim: int,
    padding_dim: Optional[int] = None,
    unpad_dim_size: Optional[int] = None,
    scale_grad=True,
):
    """
    Gather the model output across the unified parallel group.

    Returns ``x`` unchanged when no group is set. Otherwise ``x`` is gathered
    along ``gather_dim`` via ``Gather`` and, when ``padding_dim`` is given,
    the padding along that dimension is removed based on ``unpad_dim_size``.
    """
    group = get_unified_parallel_group()
    if not group:
        return x
    gathered = Gather.apply(group, x, gather_dim, scale_grad)
    if padding_dim is None:
        return gathered
    return unpadding_tensor_for_seqeunce_parallel(gathered, padding_dim, unpad_dim_size)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int):
    """
    Strip the padding from ``x`` along ``dim`` given its original size.

    The padding is the amount needed to round ``unpadded_dim_size`` up to a
    multiple of the unified parallel world size. ``x`` is returned unchanged
    when no group is set or when the original size is already a multiple.
    """
    if get_unified_parallel_group() is None:
        return x
    sp_world = get_unified_parallel_world_size()
    remainder = unpadded_dim_size % sp_world
    if remainder == 0:
        return x
    padding_size = sp_world - remainder
    assert (padding_size + unpadded_dim_size) % sp_world == 0
    return unpad_tensor(x, dim=dim, padding_size=padding_size)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def gather_seq_scatter_heads_qkv(
    qkv_tensor: Tensor,
    seq_dim: int,
    unpadded_dim_size: Optional[int] = None,
    restore_shape: bool = True,
    async_op: bool = False,
):
    """
    All-to-all a fused qkv tensor: gather the sequence, scatter the projection.

    qkv_tensor: the tensor to exchange. Its last dim is viewed as
        (3, last_dim // 3) and the trailing dim of that view is scattered.
    seq_dim: the dimension gathered by the all-to-all.
    unpadded_dim_size: when set and not a multiple of the world size, the
        gathered sequence is trimmed back to this length.
    restore_shape: if True, the output is viewed back to the input's number
        of dimensions.
    async_op: if True, the result of ``SeqAllToAll.apply`` is returned as is,
        without shape restoration or unpadding.

    Returns the input unchanged when no unified parallel group is set.
    """
    group = get_unified_parallel_group()
    if not group:
        return qkv_tensor

    world = get_unified_parallel_world_size()
    orig_shape = qkv_tensor.shape
    # Index of the dimension that the view below appends.
    scatter_dim = qkv_tensor.dim()
    qkv_proj_dim = orig_shape[-1]
    split_view = list(orig_shape)[:-1] + [3, qkv_proj_dim // 3]
    qkv_tensor = qkv_tensor.view(split_view)

    exchanged = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
    if async_op:
        return exchanged
    qkv_tensor = exchanged

    if restore_shape:
        restored = list(orig_shape)
        restored[seq_dim] *= world
        restored[-1] = qkv_proj_dim // world
        qkv_tensor = qkv_tensor.view(restored)

    # remove padding
    needs_unpad = unpadded_dim_size and unpadded_dim_size % world != 0
    if needs_unpad:
        extra = qkv_tensor.size(seq_dim) - unpadded_dim_size
        qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, extra)

    return qkv_tensor
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def gather_seq_scatter_double_head(
    qkv_tensor: Tensor,
    seq_dim: int,
    unpadded_dim_size: Optional[int] = None,
    restore_shape: bool = True,
    async_op: bool = False,
):
    """
    All-to-all a fused two-way projection tensor: gather the sequence, scatter
    the projection.

    qkv_tensor: the tensor to exchange. Its last dim is viewed as
        (2, last_dim // 2) and the trailing dim of that view is scattered.
    seq_dim: the dimension gathered by the all-to-all.
    unpadded_dim_size: when set and not a multiple of the world size, the
        gathered sequence is trimmed back to this length.
    restore_shape: if True, the output is viewed back to the input's number
        of dimensions.
    async_op: if True, the result of ``SeqAllToAll.apply`` is returned as is,
        without shape restoration or unpadding.

    Returns the input unchanged when no unified parallel group is set.
    """
    group = get_unified_parallel_group()
    if not group:
        return qkv_tensor
    world = get_unified_parallel_world_size()
    orig_shape = qkv_tensor.shape
    # Index of the dimension that the view below appends.
    scatter_dim = qkv_tensor.dim()
    bef_all2all_shape = list(orig_shape)
    qkv_proj_dim = bef_all2all_shape[-1]
    bef_all2all_shape = bef_all2all_shape[:-1] + [2, qkv_proj_dim // 2]
    qkv_tensor = qkv_tensor.view(bef_all2all_shape)
    if async_op:
        return SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
    qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)

    if restore_shape:
        out_shape = list(orig_shape)
        out_shape[seq_dim] *= world
        out_shape[-1] = qkv_proj_dim // world
        qkv_tensor = qkv_tensor.view(out_shape)

    # remove padding
    if unpadded_dim_size and unpadded_dim_size % world != 0:
        padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
        qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)

    return qkv_tensor
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class SeqAllToAll(torch.autograd.Function):
    """
    Autograd wrapper around ``all_to_all_tensor``.

    Forward scatters ``local_input`` along ``scatter_dim`` and gathers along
    ``gather_dim``; backward runs the reverse exchange (the two dimensions
    swapped) on the incoming gradient, always synchronously.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        # NOTE(review): with async_op=True the helpers return a ``wait``
        # callable rather than a tensor - confirm callers expect that.
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
        if ctx.async_op:
            input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
        else:
            input_t = grad_output[0]
        # One gradient slot per forward input:
        # (group, local_input, scatter_dim, gather_dim, async_op).
        return (
            None,
            all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
            None,
            None,
            None,
        )
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def all_to_all_tensor(
    x: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    """
    Dispatch an all-to-all to the matching implementation.

    ``_all_to_all_single`` is used when both dimensions are <= 1; every other
    combination goes through the list-based ``_all_to_all``.
    """
    both_leading = scatter_dim <= 1 and gather_dim <= 1
    if both_leading:
        return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op)
    # Original author's note (translated): execution goes through here.
    return _all_to_all(x, scatter_dim, gather_dim, group, async_op)
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    """
    List-based all-to-all: split along ``scatter_dim``, exchange one chunk
    with every rank of ``group``, and concatenate what was received along
    ``gather_dim``.

    With ``async_op`` a ``wait`` callable is returned instead; calling it
    waits for the communication and then returns the concatenated tensor.
    """
    world = dist.get_world_size(group)
    send_chunks = [chunk.contiguous() for chunk in torch.tensor_split(local_input, world, scatter_dim)]
    # Receive buffers are all shaped like the first send chunk.
    recv_chunks = [torch.empty_like(send_chunks[0]) for _ in range(world)]
    handle = dist.all_to_all(recv_chunks, send_chunks, group=group, async_op=async_op)

    def _merge():
        return torch.cat(recv_chunks, dim=gather_dim).contiguous()

    if not async_op:
        return _merge()

    def wait():
        handle.wait()
        return _merge()

    return wait
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def _all_to_all_single(x: Tensor, scatter_dim: int, gather_dim: int, group: dist.ProcessGroup, async_op: bool = False):
    """
    Perform an all-to-all over the first two dimensions of ``x``.

    Uses ``dist.all_to_all_single``. When ``scatter_dim`` is not 0 the input
    is first reshaped/transposed so the per-rank chunks sit along dim 0. When
    ``scatter_dim`` is 0 the received buffer is split into per-rank chunks
    and concatenated along ``gather_dim``.

    Returns the resulting tensor, or (when ``async_op`` is true) a ``wait``
    callable that blocks on the communication and then returns it.
    """
    sp_world_size = dist.get_world_size(group)
    assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!"
    assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!"

    scatter_on_first = scatter_dim == 0
    if not scatter_on_first:
        gather_len = x.shape[gather_dim]
        per_rank = x.shape[scatter_dim] // sp_world_size
        trailing = list(x.shape[2:])
        # Bring the rank axis to the front, then fold it into dim 0.
        x = x.reshape([gather_len, sp_world_size, per_rank] + trailing)
        x = x.transpose(0, 1)
        x = x.reshape([gather_len * sp_world_size, per_rank] + trailing)
        x = x.contiguous()

    output = torch.empty_like(x)
    comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op)

    def _finish(received):
        if scatter_on_first:
            return torch.cat(received.split(x.size(0) // sp_world_size), dim=gather_dim)
        return received

    if async_op:

        def wait():
            comm.wait()
            return _finish(output)

        return wait

    return _finish(output)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
    """
    Sync an attention result with an all-to-all under sequence parallelism.

    Returns ``x`` untouched when no unified parallel group is available.
    Otherwise ``x`` is padded along ``seq_dim`` up to a multiple of the
    unified parallel world size (only if needed) and passed through
    ``SeqAllToAll`` with ``seq_dim`` and ``head_dim`` as its two dimension
    arguments, synchronously.
    """
    group = get_unified_parallel_group()
    if not group:
        return x
    sp_world = get_unified_parallel_world_size()
    remainder = x.size(seq_dim) % sp_world
    if remainder != 0:
        # Pad so the sequence length divides evenly across ranks.
        x = pad_tensor(x, seq_dim, sp_world - remainder)
    return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def unpad_tensor(x: Tensor, dim: int, padding_size: int):
    """
    Drop the last ``padding_size`` entries of ``x`` along dimension ``dim``.

    Args:
        x: Tensor to trim.
        dim: Dimension along which the trailing entries are removed.
        padding_size: Number of trailing entries to remove.

    Returns:
        A slice of ``x`` without the trailing entries. When ``padding_size``
        is 0, ``x`` itself is returned.
    """
    # slice(0, -0) is slice(0, 0) and would yield an empty tensor, so a
    # zero padding must short-circuit.
    if padding_size == 0:
        return x
    index = [slice(None)] * len(x.shape)
    index[dim] = slice(0, -padding_size)
    # Index with a tuple: list-of-slices indexing is deprecated.
    return x[tuple(index)]
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
class Gather(torch.autograd.Function):
    """
    Autograd function that all-gathers a tensor across a process group.

    Forward gathers every rank's ``local_input`` and concatenates the
    per-rank pieces along ``dim``. Backward hands each rank the slice of the
    gradient that corresponds to its own piece.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        dim: int,
        grad_scale: Optional[bool] = False,
    ) -> Tensor:
        """
        All-gather ``local_input`` over ``group`` and concatenate along ``dim``.

        The gather buffer stacks the per-rank tensors along dim 0 and is
        allocated on the current CUDA device; it is then split back into
        per-rank pieces and concatenated along ``dim``.
        """
        world = dist.get_world_size(group)
        # Saved for backward.
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.seq_world_size = world
        ctx.part_size = local_input.size(dim)

        rows = local_input.size(0)
        gathered_shape = list(local_input.size())
        gathered_shape[0] = rows * world
        gathered = torch.empty(gathered_shape, dtype=local_input.dtype, device=torch.cuda.current_device())
        dist.all_gather_into_tensor(gathered, local_input.contiguous(), group=ctx.group)
        return torch.cat(gathered.split(rows), dim=dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
        """
        Return this rank's slice of ``grad_output`` along ``ctx.dim``.

        When ``ctx.grad_scale`` is truthy the gradient is first multiplied
        by the group world size. Only the second slot of the returned
        4-tuple is non-``None``.
        """
        scaled = grad_output * ctx.seq_world_size if ctx.grad_scale else grad_output
        local_grad = scaled.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous()
        return (None, local_grad, None, None)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def slice_tensor(tensor, dim, start, end):
    """
    Return ``tensor[start:end]`` taken along dimension ``dim``.

    All dimensions before ``dim`` are kept whole; ``dim`` is expected to be
    a non-negative index.
    """
    index = [slice(None)] * dim
    index.append(slice(start, end))
    return tensor[tuple(index)]
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def init_model_shard_cpu_group(sharding_strategy: str, device_mesh: Optional[Tuple] = None):
    """
    Initialize the CPU (gloo) process group used for model sharding.

    The world is partitioned into consecutive rank ranges of equal size and
    a gloo group is created for each; the group containing the current rank
    is stored in the module-level ``_MODEL_SHARD_CPU_GROUP``.

    The range size is ``device_mesh[1]`` when a mesh is given, otherwise
    ``min(8, world_size)`` when ``sharding_strategy`` contains "HYBRID",
    otherwise the full world size.
    """
    global _MODEL_SHARD_CPU_GROUP
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    if device_mesh is not None:
        group_size = device_mesh[1]
    elif "HYBRID" in sharding_strategy:
        group_size = min(8, world_size)
    else:
        group_size = world_size

    # Every rank walks over all groups and calls new_group for each one,
    # keeping only the group it belongs to.
    for group_index in range(world_size // group_size):
        first = group_index * group_size
        members = range(first, first + group_size)
        cpu_group = dist.new_group(members, backend="gloo")
        if rank in members:
            _MODEL_SHARD_CPU_GROUP = cpu_group
|
humo/common/distributed/basic.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/tree/main/common/distributed
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Distributed basic functions.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import torch
|
| 21 |
+
from torch import nn
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 24 |
+
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_global_rank() -> int:
    """
    Return the global rank read from the ``RANK`` environment variable.

    Falls back to 0 when the variable is not set.
    """
    raw_rank = os.getenv("RANK", "0")
    return int(raw_rank)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_local_rank() -> int:
    """
    Return the local rank read from the ``LOCAL_RANK`` environment variable.

    Falls back to 0 when the variable is not set.
    """
    raw_rank = os.getenv("LOCAL_RANK", "0")
    return int(raw_rank)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_world_size() -> int:
    """
    Return the world size read from the ``WORLD_SIZE`` environment variable.

    Falls back to 1 when the variable is not set.
    """
    raw_size = os.getenv("WORLD_SIZE", "1")
    return int(raw_size)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_device() -> torch.device:
    """
    Return the CUDA device whose index is this process's local rank.
    """
    local_index = get_local_rank()
    return torch.device("cuda", local_index)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def barrier_if_distributed(*args, **kwargs):
    """
    Run ``dist.barrier`` only when a default process group is initialized.

    All arguments are forwarded to ``dist.barrier`` and its result is
    returned; outside a distributed context this is a no-op returning None.
    """
    if not dist.is_initialized():
        return None
    return dist.barrier(*args, **kwargs)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def init_torch(cudnn_benchmark=True):
    """
    Apply the common PyTorch setup for a distributed run.

    Enables TF32 for CUDA matmul and cuDNN, sets the cuDNN benchmark flag to
    ``cudnn_benchmark``, binds this process to the CUDA device given by the
    local rank, and initializes the NCCL process group with the rank and
    world size reported by the helper functions.
    """
    backends = torch.backends
    backends.cuda.matmul.allow_tf32 = True
    backends.cudnn.allow_tf32 = True
    backends.cudnn.benchmark = cudnn_benchmark

    # Select the device before creating the NCCL group.
    torch.cuda.set_device(get_local_rank())
    dist.init_process_group(backend="nccl", rank=get_global_rank(), world_size=get_world_size())
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
    """
    Wrap ``module`` in ``DistributedDataParallel`` on the local-rank device.

    The local rank is used both as the single entry of ``device_ids`` and as
    ``output_device``; extra keyword arguments are forwarded to the wrapper.
    """
    local_rank = get_local_rank()
    return DistributedDataParallel(
        module=module,
        device_ids=[local_rank],
        output_device=local_rank,
        **kwargs,
    )
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def meta_param_init_fn(module: nn.Module) -> None:
    """
    Materialize meta-device parameters and buffers as empty CPU tensors.

    Intended for models constructed on the meta device. Values are left
    uninitialized on purpose: numerical correctness is not a concern here
    because FSDP syncs param/buffer state from rank 0 to the other ranks.
    Tensors already flattened by FSDP are left alone.
    """

    def _needs_materializing(tensor) -> bool:
        return not _is_fsdp_flattened(tensor) and tensor.is_meta

    with torch.no_grad():
        for child in module.modules():
            # recurse=False: each submodule handles only its own tensors.
            for name, param in child.named_parameters(recurse=False):
                if _needs_materializing(param):
                    setattr(child, name, nn.Parameter(torch.empty_like(param, device="cpu")))
            for name, buf in child.named_buffers(recurse=False):
                if _needs_materializing(buf):
                    setattr(child, name, torch.empty_like(buf, device="cpu"))
            # NOTE(review): nesting level reconstructed after indentation
            # loss; assumed to run once per submodule — confirm upstream.
            torch.cuda.empty_cache()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
    """
    Materialize meta-device buffers that are not persistent in state_dict.

    Handles the special case of a ``freqs`` attribute (e.g.
    RotaryEmbedding.freqs): every submodule whose ``freqs`` is a meta tensor
    gets it replaced by a freshly computed complex rotary table on the CPU.
    Asserts afterwards that no registered buffer is still on the meta
    device, then returns ``module``.
    """

    def rope_params(max_seq_len, dim, theta=10000):
        # Unit-magnitude complex table, shape (max_seq_len, dim // 2).
        assert dim % 2 == 0
        angles = torch.outer(
            torch.arange(max_seq_len),
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
        )
        return torch.polar(torch.ones_like(angles), angles)

    with torch.no_grad():
        for submodule in module.modules():
            if not hasattr(submodule, "freqs"):
                continue
            freqs = getattr(submodule, "freqs")
            if not (isinstance(freqs, torch.Tensor) and freqs.is_meta):
                continue

            # NOTE(review): the original reads submodule.dim and then
            # overrides it with hard-coded sizes (the alternative values
            # 1536 / 12 were left commented out). The read is kept so a
            # missing attribute still raises.
            dim = submodule.dim
            dim = 5120
            num_heads = 40
            d = dim // num_heads
            freqs_tensor = torch.cat(
                [
                    rope_params(1024, d - 4 * (d // 6)),
                    rope_params(1024, 2 * (d // 6)),
                    rope_params(1024, 2 * (d // 6)),
                ],
                dim=1,
            ).to(dtype=torch.cfloat, device="cpu")

            setattr(submodule, "freqs", freqs_tensor)
            print(f"Successfully materialized freqs for {submodule.__class__.__name__}")

    assert not any(b.is_meta for n, b in module.named_buffers())
    return module
|
humo/common/logger.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/blob/main/common/logger.py
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Logging utility functions.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import sys
|
| 21 |
+
from typing import Optional
|
| 22 |
+
|
| 23 |
+
from common.distributed import get_global_rank, get_local_rank, get_world_size
|
| 24 |
+
|
| 25 |
+
# Shared stdout handler. The rank tags are baked into the format string
# once, at import time, and only when the world size is greater than 1.
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.setFormatter(
    logging.Formatter(
        "".join(
            (
                "%(asctime)s ",
                f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "",
                f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "",
                "[%(threadName).12s][%(name)s][%(levelname).5s] ",
                "%(message)s",
            )
        )
    )
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
    """
    Return the logger called ``name`` (the root logger when ``name`` is None).

    The logger gets the module's shared stdout handler attached and its
    level set to INFO.
    """
    log = logging.getLogger(name)
    log.setLevel(logging.INFO)
    log.addHandler(_default_handler)
    return log
|
humo/configs/inference/generate.yaml
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: humo.generate
|
| 3 |
+
name: Generator
|
| 4 |
+
|
| 5 |
+
dit:
|
| 6 |
+
model:
|
| 7 |
+
__inherit__: humo/configs/models/Wan_14B_I2V.yaml
|
| 8 |
+
__object__:
|
| 9 |
+
path: humo.models.wan_modules.model_humo
|
| 10 |
+
name: WanModel
|
| 11 |
+
insert_audio: True
|
| 12 |
+
zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
|
| 13 |
+
zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
|
| 14 |
+
checkpoint_dir: ./weights/HuMo/HuMo-17B
|
| 15 |
+
compile: False
|
| 16 |
+
init_with_meta_device: True
|
| 17 |
+
gradient_checkpoint: True
|
| 18 |
+
fsdp:
|
| 19 |
+
sharding_strategy: _HYBRID_SHARD_ZERO2
|
| 20 |
+
sp_size: 1
|
| 21 |
+
dtype: bfloat16
|
| 22 |
+
|
| 23 |
+
vae:
|
| 24 |
+
checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
|
| 25 |
+
vae_stride: [ 4, 8, 8 ]
|
| 26 |
+
scaling_factor: 0.9152
|
| 27 |
+
compile: False
|
| 28 |
+
grouping: True
|
| 29 |
+
use_sample: False
|
| 30 |
+
dtype: bfloat16
|
| 31 |
+
|
| 32 |
+
text:
|
| 33 |
+
t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
|
| 34 |
+
t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
|
| 35 |
+
dropout: 0.1
|
| 36 |
+
dtype: bfloat16
|
| 37 |
+
fsdp:
|
| 38 |
+
enabled: True
|
| 39 |
+
sharding_strategy: HYBRID_SHARD
|
| 40 |
+
|
| 41 |
+
diffusion:
|
| 42 |
+
schedule:
|
| 43 |
+
type: lerp
|
| 44 |
+
T: 1000.0
|
| 45 |
+
sampler:
|
| 46 |
+
type: euler
|
| 47 |
+
prediction_type: v_lerp
|
| 48 |
+
timesteps:
|
| 49 |
+
training:
|
| 50 |
+
type: logitnormal
|
| 51 |
+
loc: 0.0
|
| 52 |
+
scale: 1.0
|
| 53 |
+
sampling:
|
| 54 |
+
type: uniform_trailing
|
| 55 |
+
steps: 50
|
| 56 |
+
shift: 5.0
|
| 57 |
+
|
| 58 |
+
audio:
|
| 59 |
+
vocal_separator: ./weights/HuMo/audio_separator/Kim_Vocal_2.onnx
|
| 60 |
+
wav2vec_model: ./weights/whisper-large-v3
|
| 61 |
+
|
| 62 |
+
generation:
|
| 63 |
+
mode: "TIA" # TA, TIA
|
| 64 |
+
extract_audio_feat: True
|
| 65 |
+
seed: 666666
|
| 66 |
+
frames: 97
|
| 67 |
+
fps: 25
|
| 68 |
+
height: 480 # 720 480
|
| 69 |
+
width: 832 # 1280 832
|
| 70 |
+
batch_size: 1
|
| 71 |
+
sequence_parallel: 8
|
| 72 |
+
output:
|
| 73 |
+
dir: ./output
|
| 74 |
+
# positive_prompt: ./examples/test_case.json
|
| 75 |
+
sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
| 76 |
+
scale_a: 5.5
|
| 77 |
+
scale_t: 5.0
|
| 78 |
+
step_change: 980
|
humo/configs/inference/generate_1_7B.yaml
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: humo.generate_1_7B
|
| 3 |
+
name: Generator
|
| 4 |
+
|
| 5 |
+
dit:
|
| 6 |
+
model:
|
| 7 |
+
__inherit__: humo/configs/models/Wan_1.3B.yaml
|
| 8 |
+
__object__:
|
| 9 |
+
path: humo.models.wan_modules.model_humo
|
| 10 |
+
name: WanModel
|
| 11 |
+
insert_audio: True
|
| 12 |
+
zero_vae_path: ./weights/HuMo/zero_vae_129frame.pt
|
| 13 |
+
zero_vae_720p_path: ./weights/HuMo/zero_vae_720p_161frame.pt
|
| 14 |
+
checkpoint_dir: ./weights/HuMo/HuMo-1.7B/ema.pth #./weights/HuMo/HuMo-17B
|
| 15 |
+
compile: False
|
| 16 |
+
init_with_meta_device: True
|
| 17 |
+
gradient_checkpoint: True
|
| 18 |
+
fsdp:
|
| 19 |
+
sharding_strategy: _HYBRID_SHARD_ZERO2
|
| 20 |
+
sp_size: 1
|
| 21 |
+
|
| 22 |
+
vae:
|
| 23 |
+
checkpoint: ./weights/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
|
| 24 |
+
vae_stride: [ 4, 8, 8 ]
|
| 25 |
+
scaling_factor: 0.9152
|
| 26 |
+
compile: False
|
| 27 |
+
grouping: True
|
| 28 |
+
use_sample: False
|
| 29 |
+
dtype: bfloat16
|
| 30 |
+
|
| 31 |
+
text:
|
| 32 |
+
t5_checkpoint: ./weights/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
|
| 33 |
+
t5_tokenizer: ./weights/Wan2.1-T2V-1.3B/google/umt5-xxl
|
| 34 |
+
dropout: 0.1
|
| 35 |
+
dtype: bfloat16
|
| 36 |
+
fsdp:
|
| 37 |
+
enabled: True
|
| 38 |
+
sharding_strategy: HYBRID_SHARD
|
| 39 |
+
|
| 40 |
+
diffusion:
|
| 41 |
+
schedule:
|
| 42 |
+
type: lerp
|
| 43 |
+
T: 1000.0
|
| 44 |
+
sampler:
|
| 45 |
+
type: euler
|
| 46 |
+
prediction_type: v_lerp
|
| 47 |
+
timesteps:
|
| 48 |
+
training:
|
| 49 |
+
type: logitnormal
|
| 50 |
+
loc: 0.0
|
| 51 |
+
scale: 1.0
|
| 52 |
+
sampling:
|
| 53 |
+
type: uniform_trailing
|
| 54 |
+
steps: 50
|
| 55 |
+
shift: 5.0
|
| 56 |
+
|
| 57 |
+
audio:
|
| 58 |
+
vocal_separator: ./weights/audio_separator/Kim_Vocal_2.onnx
|
| 59 |
+
wav2vec_model: ./weights/whisper-large-v3
|
| 60 |
+
|
| 61 |
+
generation:
|
| 62 |
+
mode: "TIA" # TA, TIA
|
| 63 |
+
extract_audio_feat: True
|
| 64 |
+
seed: 666666
|
| 65 |
+
frames: 97
|
| 66 |
+
fps: 25
|
| 67 |
+
height: 720 # 480
|
| 68 |
+
width: 1280 # 832
|
| 69 |
+
batch_size: 1
|
| 70 |
+
output:
|
| 71 |
+
dir: ./output
|
| 72 |
+
sample_neg_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
| 73 |
+
scale_t: 7.5
|
| 74 |
+
scale_i: 4.0
|
| 75 |
+
scale_a: 7.5
|
| 76 |
+
# step_change: 980
|
humo/configs/models/Wan_1.3B.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: ???
|
| 3 |
+
name: ???
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
text_len: 512
|
| 7 |
+
patch_size: [ 1, 2, 2 ]
|
| 8 |
+
dim: 1536
|
| 9 |
+
ffn_dim: 8960
|
| 10 |
+
freq_dim: 256
|
| 11 |
+
model_type: "t2v"
|
| 12 |
+
num_heads: 12
|
| 13 |
+
num_layers: 30
|
| 14 |
+
window_size: [ -1, -1 ]
|
| 15 |
+
qk_norm: True
|
| 16 |
+
cross_attn_norm: True
|
| 17 |
+
eps: 1e-6
|
humo/configs/models/Wan_1.3B_I2V.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: ???
|
| 3 |
+
name: ???
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
text_len: 512
|
| 7 |
+
patch_size: [ 1, 2, 2 ]
|
| 8 |
+
dim: 1536
|
| 9 |
+
ffn_dim: 8960
|
| 10 |
+
freq_dim: 256
|
| 11 |
+
in_dim: 36
|
| 12 |
+
model_type: "i2v"
|
| 13 |
+
num_heads: 12
|
| 14 |
+
num_layers: 30
|
| 15 |
+
window_size: [ -1, -1 ]
|
| 16 |
+
qk_norm: True
|
| 17 |
+
cross_attn_norm: True
|
| 18 |
+
eps: 1e-6
|
humo/configs/models/Wan_14B.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: ???
|
| 3 |
+
name: ???
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
text_len: 512
|
| 7 |
+
patch_size: [ 1, 2, 2 ]
|
| 8 |
+
dim: 5120
|
| 9 |
+
ffn_dim: 13824
|
| 10 |
+
freq_dim: 256
|
| 11 |
+
model_type: "t2v"
|
| 12 |
+
num_heads: 40
|
| 13 |
+
num_layers: 40
|
| 14 |
+
window_size: [ -1, -1 ]
|
| 15 |
+
qk_norm: True
|
| 16 |
+
cross_attn_norm: True
|
| 17 |
+
eps: 1e-6
|
humo/configs/models/Wan_14B_I2V.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: ???
|
| 3 |
+
name: ???
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
text_len: 512
|
| 7 |
+
patch_size: [ 1, 2, 2 ]
|
| 8 |
+
dim: 5120
|
| 9 |
+
ffn_dim: 13824
|
| 10 |
+
freq_dim: 256
|
| 11 |
+
in_dim: 36
|
| 12 |
+
model_type: "i2v"
|
| 13 |
+
num_heads: 40
|
| 14 |
+
num_layers: 40
|
| 15 |
+
window_size: [ -1, -1 ]
|
| 16 |
+
qk_norm: True
|
| 17 |
+
cross_attn_norm: True
|
| 18 |
+
eps: 1e-6
|
humo/generate.py
ADDED
|
@@ -0,0 +1,984 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Inference codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import os
|
| 17 |
+
import gc
|
| 18 |
+
import random
|
| 19 |
+
import sys
|
| 20 |
+
import mediapy
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.distributed as dist
|
| 24 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 25 |
+
from einops import rearrange
|
| 26 |
+
from omegaconf import OmegaConf
|
| 27 |
+
from PIL import Image, ImageOps
|
| 28 |
+
from torchvision.transforms import ToTensor
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
from torch.distributed.device_mesh import init_device_mesh
|
| 31 |
+
from torch.distributed.fsdp import (
|
| 32 |
+
BackwardPrefetch,
|
| 33 |
+
FullyShardedDataParallel,
|
| 34 |
+
MixedPrecision,
|
| 35 |
+
ShardingStrategy,
|
| 36 |
+
)
|
| 37 |
+
from common.distributed import (
|
| 38 |
+
get_device,
|
| 39 |
+
get_global_rank,
|
| 40 |
+
get_local_rank,
|
| 41 |
+
meta_param_init_fn,
|
| 42 |
+
meta_non_persistent_buffer_init_fn,
|
| 43 |
+
init_torch,
|
| 44 |
+
)
|
| 45 |
+
from common.distributed.advanced import (
|
| 46 |
+
init_unified_parallel,
|
| 47 |
+
get_unified_parallel_world_size,
|
| 48 |
+
get_sequence_parallel_rank,
|
| 49 |
+
init_model_shard_cpu_group,
|
| 50 |
+
)
|
| 51 |
+
from common.logger import get_logger
|
| 52 |
+
from common.config import create_object
|
| 53 |
+
from common.distributed import get_device, get_global_rank
|
| 54 |
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
| 55 |
+
from humo.models.wan_modules.t5 import T5EncoderModel
|
| 56 |
+
from humo.models.wan_modules.vae import WanVAE
|
| 57 |
+
from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
|
| 58 |
+
from contextlib import contextmanager
|
| 59 |
+
import torch.cuda.amp as amp
|
| 60 |
+
from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 61 |
+
from humo.utils.audio_processor_whisper import AudioProcessor
|
| 62 |
+
from humo.utils.wav2vec import linear_interpolation_fps
|
| 63 |
+
from torchao.quantization import quantize_
|
| 64 |
+
|
| 65 |
+
import torch._dynamo as dynamo
|
| 66 |
+
dynamo.config.capture_scalar_outputs = True
|
| 67 |
+
torch.set_float32_matmul_precision("high")
|
| 68 |
+
|
| 69 |
+
import torch
|
| 70 |
+
import torch.nn as nn
|
| 71 |
+
import transformer_engine.pytorch as te
|
| 72 |
+
|
| 73 |
+
# Image preprocessing: ToTensor followed by Normalize(mean=0.5, std=0.5).
image_transform = Compose([ToTensor(), Normalize(mean=0.5, std=0.5)])

# Named size presets: each "A*B" key maps to the integer pair (A, B).
SIZE_CONFIGS = {
    "720*1280": (720, 1280),
    "1280*720": (1280, 720),
    "480*832": (480, 832),
    "832*480": (832, 480),
    "1024*1024": (1024, 1024),
}
|
| 85 |
+
|
| 86 |
+
def clever_format(nums, format="%.2f"):
    """
    Render numbers with a magnitude suffix (T, G, M, K or B).

    Each value strictly greater than 1e12 / 1e9 / 1e6 / 1e3 is divided by
    that threshold and suffixed with "T" / "G" / "M" / "K"; anything else is
    formatted as-is with the suffix "B". ``format`` is a %-style template
    applied to the scaled value.

    Accepts a single number or an iterable of numbers. Returns one string
    when exactly one value was formatted, otherwise a tuple of strings.
    """
    from typing import Iterable

    values = nums if isinstance(nums, Iterable) else [nums]
    # Checked from largest to smallest; the first threshold exceeded wins.
    units = ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K"))

    rendered = []
    for value in values:
        for threshold, suffix in units:
            if value > threshold:
                rendered.append(format % (value / threshold) + suffix)
                break
        else:
            rendered.append(format % value + "B")

    if len(rendered) == 1:
        return rendered[0]
    return tuple(rendered)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# --- Transformer Engine (TE) FP8 setup ---
|
| 110 |
+
import torch
|
| 111 |
+
import torch.nn as nn
|
| 112 |
+
import contextlib
|
| 113 |
+
import transformer_engine.pytorch as te
|
| 114 |
+
|
| 115 |
+
# FP8 autocast compatibility for different TE versions
|
| 116 |
+
# Define `make_fp8_ctx(enabled)` once at import time, choosing the variant that
# matches what the installed Transformer Engine exposes.
try:
    # Preferred modern API
    from transformer_engine.pytorch import fp8_autocast
except Exception:
    # fp8_autocast could not be imported at all: FP8 becomes a no-op.
    def make_fp8_ctx(enabled: bool = True):
        """Return a do-nothing context manager (no usable FP8 autocast)."""
        return contextlib.nullcontext()
else:
    try:
        # Newer TE: use recipe-based API
        from transformer_engine.common.recipe import DelayedScaling, Format
    except Exception:
        # Very old variant that might still accept fp8_format directly
        def make_fp8_ctx(enabled: bool = True):
            """Return `te.fp8_autocast` with E4M3, or a no-op without `te.FP8Format`."""
            if hasattr(te, "FP8Format"):
                return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
            # If TE doesn't have FP8Format, just no-op
            return contextlib.nullcontext()
    else:
        def make_fp8_ctx(enabled: bool = True):
            """Return an FP8 autocast context (E4M3 delayed scaling), or a no-op if disabled."""
            if enabled:
                recipe = DelayedScaling(fp8_format=Format.E4M3)  # E4M3 format
                return fp8_autocast(enabled=True, fp8_recipe=recipe)
            return contextlib.nullcontext()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# TE sometimes exposes Linear at different paths; this normalizes it.
|
| 141 |
+
# Resolve the Transformer Engine linear layer exactly once. `te.Linear` is the
# normal location; the module-path import is the fallback for layouts that do
# not re-export it. `torch`, `nn` and `te` are already imported earlier in this
# file, so the repeated re-imports and the two duplicate copies of this
# try/except that used to follow have been removed.
try:
    TELinear = te.Linear
except AttributeError:  # very old layouts
    from transformer_engine.pytorch.modules.linear import Linear as TELinear  # type: ignore
|
| 164 |
+
|
| 165 |
+
def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
|
| 166 |
+
"""
|
| 167 |
+
Allow TE only where it's shape-safe & beneficial.
|
| 168 |
+
Skip small/special layers (time/timestep/pos embeds, heads).
|
| 169 |
+
Enforce multiples of 16 for in/out features (FP8 kernel friendly).
|
| 170 |
+
Also skip very small projections likely to see M=1.
|
| 171 |
+
"""
|
| 172 |
+
blocked_keywords = (
|
| 173 |
+
"time_embedding", "timestep", "time_embed",
|
| 174 |
+
"time_projection", "pos_embedding", "pos_embed",
|
| 175 |
+
"to_logits", "logits", "final_proj", "proj_out", "output_projection",
|
| 176 |
+
)
|
| 177 |
+
if any(k in fullname for k in blocked_keywords):
|
| 178 |
+
return False
|
| 179 |
+
|
| 180 |
+
# TE FP8 kernels like K, N divisible by 16
|
| 181 |
+
if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
|
| 182 |
+
return False
|
| 183 |
+
|
| 184 |
+
# Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs
|
| 185 |
+
if lin.in_features < 512 or lin.out_features < 512:
|
| 186 |
+
return False
|
| 187 |
+
|
| 188 |
+
# Whitelist: only convert inside transformer blocks if you know their prefix
|
| 189 |
+
# This further reduces risk of catching special heads elsewhere.
|
| 190 |
+
allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
|
| 191 |
+
if not any(tok in fullname for tok in allowed_context):
|
| 192 |
+
return False
|
| 193 |
+
|
| 194 |
+
return True
|
| 195 |
+
|
| 196 |
+
@torch.no_grad()
def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
    """Replace eligible `nn.Linear` children of `module` with `TELinear`, in place.

    The module tree is walked recursively. Every `nn.Linear` accepted by
    `allow_pred` is replaced by a bfloat16 `TELinear` of the same shape on the
    same device, with weight and bias copied across.

    Args:
        module: Root module to convert; it is modified in place.
        allow_pred: Callable `(qualified_name, linear) -> bool`; `None` accepts
            every linear layer.
        _prefix: Dotted path of `module` (used by the recursion).

    Returns:
        The same `module` object.
    """
    # Snapshot the children because attributes are reassigned while iterating.
    for child_name, submodule in list(module.named_children()):
        qualified = child_name if not _prefix else f"{_prefix}.{child_name}"
        convert_linears_to_te_fp8(submodule, allow_pred, qualified)

        if not isinstance(submodule, nn.Linear):
            continue
        if allow_pred is not None and not allow_pred(qualified, submodule):
            continue

        has_bias = submodule.bias is not None
        replacement = TELinear(
            in_features=submodule.in_features,
            out_features=submodule.out_features,
            bias=has_bias,
            params_dtype=torch.bfloat16,
        ).to(submodule.weight.device)

        # Cast to the new layer's own parameter dtype before copying.
        replacement.weight.copy_(submodule.weight.to(replacement.weight.dtype))
        if has_bias:
            replacement.bias.copy_(submodule.bias.to(replacement.bias.dtype))

        setattr(module, child_name, replacement)
    return module
|
| 219 |
+
|
| 220 |
+
class Generator():
|
| 221 |
+
def __init__(self, config: DictConfig):
    """Store a read-only copy of `config`, create a logger and load all models."""
    frozen_config = config.copy()
    OmegaConf.set_readonly(frozen_config, True)
    self.config = frozen_config
    self.logger = get_logger(type(self).__name__)

    # init_torch(cudnn_benchmark=False)
    self.configure_models()
|
| 228 |
+
|
| 229 |
+
def entrypoint(self, *args, **kwargs):
    """Run one generation job by delegating to `inference_loop`.

    `inference_loop` requires `prompt`, `ref_img_path`, `audio_path`,
    `output_dir` and `filename`, so a bare `self.inference_loop()` call can
    only raise `TypeError`. All positional and keyword arguments are therefore
    forwarded unchanged.

    Args:
        *args: Positional arguments for `inference_loop`.
        **kwargs: Keyword arguments for `inference_loop`.
    """
    self.inference_loop(*args, **kwargs)
|
| 232 |
+
|
| 233 |
+
def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
    """Resolve an FSDP strategy name and, for hybrid strategies, a device mesh.

    Args:
        sharding_strategy: Name of a `ShardingStrategy` member.
        device_mesh_config: Mesh shape (any iterable of ints) or `None`.

    Returns:
        `(device_mesh, fsdp_strategy)`. `device_mesh` is `None` unless the
        strategy is one of the two hybrid ones and a mesh shape was given.

    Raises:
        KeyError: If `sharding_strategy` is not a `ShardingStrategy` name.
    """
    fsdp_strategy = ShardingStrategy[sharding_strategy]
    hybrid_strategies = (
        ShardingStrategy._HYBRID_SHARD_ZERO2,
        ShardingStrategy.HYBRID_SHARD,
    )

    if fsdp_strategy not in hybrid_strategies or device_mesh_config is None:
        return None, fsdp_strategy

    mesh = init_device_mesh("cuda", tuple(device_mesh_config))
    return mesh, fsdp_strategy
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def configure_models(self):
    """Load every model needed for generation.

    Order: DiT (then TE linear conversion and `torch.compile`), VAE, optional
    audio feature extractor (on CPU), text encoder.
    """
    self.configure_dit_model(device="cuda")

    # Convert eligible linears before compiling so the compiled module wraps
    # the final layer types.
    self.dit.eval().to("cuda")
    convert_linears_to_te_fp8(self.dit)
    self.dit = torch.compile(self.dit)

    self.configure_vae_model(device="cuda")

    wants_audio_features = self.config.generation.get('extract_audio_feat', False)
    if wants_audio_features:
        self.configure_wav2vec(device="cpu")

    self.configure_text_model(device="cuda")

    # Currently disabled alternatives, kept for reference:
    # # Initialize fsdp.
    # self.configure_dit_fsdp_model()
    # self.configure_text_fsdp_model()

    # quantize_(self.text_encoder, Int8WeightOnlyConfig())
    # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def configure_dit_model(self, device=get_device()):
    """Create the DiT, load its weights in bfloat16 and move it to `device`.

    The model is instantiated on the meta device, so no parameter memory is
    allocated until the checkpoint tensors are assigned. The checkpoint is
    either a single `.pth` file or a directory of safetensors shards indexed
    by `humo.safetensors.index.json`. Weights are loaded with `strict=False`,
    so the number of missing / unexpected keys is logged for both formats.

    Args:
        device: Target device. `"cuda"` or the value of `get_device()` resolves
            to `get_device()`; any other value is used as given.
    """
    init_unified_parallel(self.config.dit.sp_size)
    self.sp_size = get_unified_parallel_world_size()

    # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
    init_device = "meta"
    with torch.device(init_device):
        self.dit = create_object(self.config.dit.model)
        self.dit = self.dit.to(dtype=torch.bfloat16)
    self.logger.info(f"Load DiT model on {init_device}.")
    self.dit.eval().requires_grad_(False)

    path = self.config.dit.checkpoint_dir

    def _as_bf16(value):
        # Only floating-point tensors are cast; everything else passes through.
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.to(dtype=torch.bfloat16, copy=False)
        return value

    def load_custom_sharded_weights(model_dir, base_name):
        # Merge every shard listed in the index into one CPU state dict.
        from safetensors.torch import load_file
        import json

        index_path = f"{model_dir}/{base_name}.safetensors.index.json"
        with open(index_path, "r", encoding="utf-8") as f:
            weight_map = json.load(f)["weight_map"]

        state_dict = {}
        # Sorted so the shard load order is the same on every run.
        for shard_file in sorted(set(weight_map.values())):
            shard_state = load_file(f"{model_dir}/{shard_file}", device="cpu")
            state_dict.update({k: _as_bf16(v) for k, v in shard_state.items()})
        return state_dict

    if path.endswith(".pth"):
        # Load to CPU first; the whole module is moved afterwards.
        state = torch.load(path, map_location="cpu", mmap=True)
        # Cast in place so the loaded mapping object (and anything attached to
        # it) is passed on unchanged.
        for key, value in state.items():
            state[key] = _as_bf16(value)
    else:
        state = load_custom_sharded_weights(path, 'humo')

    # assign=True replaces the meta parameters with the loaded tensors.
    missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
    self.logger.info(
        f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
    )

    self.dit = meta_non_persistent_buffer_init_fn(self.dit)

    target_device = get_device() if device in [get_device(), "cuda"] else device
    self.dit.to(target_device)  # dtype already bf16

    # Print model size.
    params = sum(p.numel() for p in self.dit.parameters())
    self.logger.info(
        f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
    )
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def configure_vae_model(self, device=get_device()):
    """Create the VAE and load the `zero_vae` tensor for the configured height.

    Args:
        device: Device passed to `WanVAE`.

    Raises:
        ValueError: If `config.generation.height` is neither 480 nor 720.
    """
    vae_cfg = self.config.vae
    self.vae_stride = vae_cfg.vae_stride
    self.vae = WanVAE(vae_pth=vae_cfg.checkpoint, device=device)

    height = self.config.generation.height
    if height not in (480, 720):
        raise ValueError(f"Unsupported height {height} for zero-vae.")

    # Only the path for the active resolution is read from the config.
    if height == 480:
        zero_vae_path = self.config.dit.zero_vae_path
    else:
        zero_vae_path = self.config.dit.zero_vae_720p_path
    self.zero_vae = torch.load(zero_vae_path)
|
| 342 |
+
|
| 343 |
+
def configure_wav2vec(self, device=get_device()):
    """Create `self.audio_processor` from the audio section of the config.

    Args:
        device: Device passed to `AudioProcessor`.
    """
    audio_cfg = self.config.audio
    vocals_dir = os.path.join(self.config.generation.output.dir, "vocals")

    # Positional order is dictated by AudioProcessor. The first two values are
    # presumably the sample rate and the frame rate - confirm against its signature.
    processor_args = (
        16000,
        25,
        audio_cfg.wav2vec_model,
        "all",
        audio_cfg.vocal_separator,
        None,  # not separate
        vocals_dir,
    )
    self.audio_processor = AudioProcessor(*processor_args, device=device)
|
| 357 |
+
|
| 358 |
+
def configure_text_model(self, device=get_device()):
    """Create the bfloat16 T5 text encoder on `device`.

    Args:
        device: Device passed to `T5EncoderModel`.
    """
    text_cfg = self.config.text
    encoder_kwargs = dict(
        text_len=self.config.dit.model.text_len,
        dtype=torch.bfloat16,
        device=device,
        checkpoint_path=text_cfg.t5_checkpoint,
        tokenizer_path=text_cfg.t5_tokenizer,
    )
    self.text_encoder = T5EncoderModel(**encoder_kwargs)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def configure_dit_fsdp_model(self):
    """Wrap `self.dit` in FSDP, one wrapped unit per `WanAttentionBlock`."""
    from humo.models.wan_modules.model_humo import WanAttentionBlock

    wrappable_types = (WanAttentionBlock,)
    fsdp_cfg = self.config.dit.fsdp

    # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
    init_model_shard_cpu_group(
        fsdp_cfg.sharding_strategy,
        fsdp_cfg.get("device_mesh", None),
    )

    # Assert that dit has wrappable blocks.
    assert any(isinstance(m, wrappable_types) for m in self.dit.modules())

    def wrap_dit_blocks(module, recurse, *args, **kwargs):
        # Keep recursing; wrap once a DiT block is reached.
        return recurse or isinstance(module, wrappable_types)

    device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
        fsdp_cfg.sharding_strategy,
        fsdp_cfg.get("device_mesh", None),
    )
    precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )

    # Apply FSDP.
    self.dit = FullyShardedDataParallel(
        self.dit,
        auto_wrap_policy=wrap_dit_blocks,
        sharding_strategy=fsdp_strategy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=get_local_rank(),
        use_orig_params=False,
        sync_module_states=True,
        forward_prefetch=True,
        limit_all_gathers=False,  # False for ZERO2.
        mixed_precision=precision,
        device_mesh=device_mesh,
        param_init_fn=meta_param_init_fn,
    )
    # self.dit.to(get_device())
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def configure_text_fsdp_model(self):
    """Wrap the text encoder in FSDP, or just move it to the device if FSDP is off."""
    text_cfg = self.config.text

    # If FSDP is not enabled, put text_encoder to GPU and return.
    if not text_cfg.fsdp.enabled:
        self.text_encoder.to(get_device())
        return

    # from transformers.models.t5.modeling_t5 import T5Block
    from humo.models.wan_modules.t5 import T5SelfAttention

    wrapped_types = (torch.nn.Embedding, T5SelfAttention)
    # text_blocks_names = ("QWenBlock", "QWenModel")  # QWen cannot be imported. Use str.

    def wrap_text_blocks(module, recurse, *args, **kwargs):
        # Keep recursing; wrap embeddings and self-attention modules.
        return recurse or isinstance(module, wrapped_types)

    # One dtype from the config is used for parameters, reductions and buffers.
    encoder_dtype = getattr(torch, text_cfg.dtype)
    device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
        text_cfg.fsdp.sharding_strategy,
        text_cfg.fsdp.get("device_mesh", None),
    )
    precision = MixedPrecision(
        param_dtype=encoder_dtype,
        reduce_dtype=encoder_dtype,
        buffer_dtype=encoder_dtype,
    )

    # Apply FSDP.
    self.text_encoder = FullyShardedDataParallel(
        module=self.text_encoder,
        auto_wrap_policy=wrap_text_blocks,
        sharding_strategy=fsdp_strategy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=get_local_rank(),
        use_orig_params=False,
        sync_module_states=False,
        forward_prefetch=True,
        limit_all_gathers=True,
        mixed_precision=precision,
        device_mesh=device_mesh,
    )
    self.text_encoder.to(get_device()).requires_grad_(False)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def load_image_latent_ref_id(self, path: str, size, device):
    """Encode one or more reference images into VAE latents.

    Each image is converted to RGB, resized with LANCZOS to fit inside the
    target frame while keeping its aspect ratio, centred on a white canvas of
    exactly the target size, turned into a tensor normalised with mean 0.5 /
    std 0.5, and passed to `self.vae.encode`.

    Args:
        path: A single image path, or a sequence of image paths.
        size: Target size, indexed as `(width, height)`.
        device: Device argument passed to `self.vae.encode`.

    Returns:
        For a single image, whatever `self.vae.encode` returns for it. For
        several images, a one-element list holding the first latent of each
        image concatenated along dim 1.

    Raises:
        ValueError: If `path` is an empty sequence.
    """
    target_h, target_w = size[1], size[0]
    # Build the tensor transform once rather than once per image.
    to_normalised_tensor = Compose([ToTensor(), Normalize(0.5, 0.5)])

    def _encode_image(image_path):
        # Load, letterbox onto a white canvas and VAE-encode one image.
        with Image.open(image_path) as opened:
            img = opened.convert("RGB")

        # Calculate the required size to keep aspect ratio and fill the rest with padding.
        img_ratio = img.width / img.height
        target_ratio = target_w / target_h
        if img_ratio > target_ratio:  # Image is wider than target
            new_width = target_w
            new_height = max(1, int(new_width / img_ratio))
        else:  # Image is taller than target
            new_height = target_h
            new_width = max(1, int(new_height * img_ratio))
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Pad to the target size with the resized image centred; an odd
        # remainder goes to the right / bottom edge.
        delta_w = target_w - img.size[0]
        delta_h = target_h - img.size[1]
        padding = (
            delta_w // 2,
            delta_h // 2,
            delta_w - (delta_w // 2),
            delta_h - (delta_h // 2),
        )
        canvas = ImageOps.expand(img, padding, fill=(255, 255, 255))

        # unsqueeze(1) inserts a length-1 axis after the channel axis.
        tensor = to_normalised_tensor(canvas)
        return self.vae.encode([tensor.unsqueeze(1)], device)

    paths = [path] if isinstance(path, str) else list(path)
    if not paths:
        raise ValueError("load_image_latent_ref_id needs at least one image path.")

    if len(paths) == 1:
        return _encode_image(paths[0])

    ref_vae_latents = [_encode_image(image_path)[0] for image_path in paths]
    return [torch.cat(ref_vae_latents, dim=1)]
|
| 540 |
+
|
| 541 |
+
def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
    """Slice per-frame audio features into one window per group of four frames.

    `1 + (frame_num - 1) // 4` windows are produced. The first window covers
    frames `frame0_idx - audio_shift .. frame0_idx + audio_shift` and is
    prefixed with three all-zero frames. Every later window `k` covers
    `frame0_idx + 1 + 4 * (k - 1) - audio_shift` up to (excluding)
    `frame0_idx + 1 + 4 * k + audio_shift`. Frame indices outside `audio_emb`
    are filled with zeros. All windows therefore have `2 * audio_shift + 4`
    frames for any `audio_shift`; the first window previously assumed
    `audio_shift == 2`.

    Args:
        audio_emb: Tensor indexed by frame along dim 0.
        frame_num: Number of video frames to cover (must be >= 1).
        frame0_idx: Index in `audio_emb` of the first video frame.
        audio_shift: Extra context frames on each side of a window.

    Returns:
        `(windows, next_idx)`, where `windows` stacks all windows along a new
        dim 0 and `next_idx` is the end of the last window minus `audio_shift`.

    Raises:
        ValueError: If `frame_num` yields no window at all.
    """
    num_windows = 1 + (frame_num - 1) // 4
    if num_windows < 1:
        raise ValueError(f"frame_num must be at least 1, got {frame_num}")

    total_frames = audio_emb.shape[0]
    frame_shape = tuple(audio_emb.shape[1:])
    # new_zeros keeps the dtype and device of audio_emb.
    zero_frame = audio_emb.new_zeros(frame_shape)
    lead_pad = audio_emb.new_zeros((3, *frame_shape))

    def _gather(start, stop):
        # Frames [start, stop), with zeros for out-of-range indices.
        return torch.stack(
            [
                audio_emb[i] if 0 <= i < total_frames else zero_frame
                for i in range(start, stop)
            ],
            dim=0,
        )

    windows = []
    for idx in range(num_windows):
        if idx == 0:
            start = frame0_idx - audio_shift
            stop = frame0_idx + audio_shift + 1
            window = torch.cat((lead_pad, _gather(start, stop)), dim=0)
        else:
            start = frame0_idx + 1 + 4 * (idx - 1) - audio_shift
            stop = frame0_idx + 1 + 4 * idx + audio_shift
            window = _gather(start, stop)
        windows.append(window)

    return torch.stack(windows, dim=0), stop - audio_shift
|
| 566 |
+
|
| 567 |
+
def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
    """Merge raw audio-encoder features into the layout used downstream.

    For `"wav2vec"` the input is returned unchanged. For `"whisper"`, indices
    0-31 along dim 2 are averaged in four groups of eight and index 32 is taken
    as-is; each of the five results goes through
    `linear_interpolation_fps(..., 50, 25)`, they are stacked along dim 2, and
    the first element along dim 0 is returned.

    Args:
        audio_emb: Feature tensor (for whisper: at least 33 entries along dim 2).
        wav_enc_type: `"whisper"` or `"wav2vec"`.

    Raises:
        ValueError: For any other `wav_enc_type`.
    """
    if wav_enc_type == "wav2vec":
        return audio_emb
    if wav_enc_type != "whisper":
        raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")

    grouped = [audio_emb[:, :, lo:lo + 8].mean(dim=2) for lo in (0, 8, 16, 24)]
    grouped.append(audio_emb[:, :, 32])

    resampled = [linear_interpolation_fps(feat, 50, 25) for feat in grouped]
    return torch.stack(resampled, dim=2)[0]
|
| 581 |
+
|
| 582 |
+
def parse_output(self, output):
    """Split a model output into `(latent, mask)`; the mask is always `None`."""
    return output[0], None
|
| 586 |
+
|
| 587 |
+
def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null):
    """Guided noise prediction for text + image + audio conditioning.

    Runs the DiT three times: with `arg_tia`, with `arg_ti`, and with a
    negative branch that is `arg_i` while `t > step_change` and `arg_null`
    otherwise. The result is

        scale_a * (pos_tia - pos_ti) + w * (pos_ti - neg) + neg

    where `w` is `scale_t` while `t > step_change` and `scale_t - 2.0` after.
    The CUDA cache is emptied after every forward pass.
    """
    def _predict(cond_kwargs):
        prediction, _ = self.parse_output(self.dit(latents, t=timestep, **cond_kwargs))
        torch.cuda.empty_cache()
        return prediction

    pos_tia = _predict(arg_tia)
    pos_ti = _predict(arg_ti)

    gen_cfg = self.config.generation
    if t > step_change:
        # img included in null, same with official Wan-2.1
        neg = _predict(arg_i)
        text_weight = gen_cfg.scale_t
    else:
        # img not included in null
        neg = _predict(arg_null)
        text_weight = gen_cfg.scale_t - 2.0

    return gen_cfg.scale_a * (pos_tia - pos_ti) + text_weight * (pos_ti - neg) + neg
|
| 617 |
+
|
| 618 |
+
def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null):
    """Guided noise prediction for text + image conditioning (no audio).

    Runs the DiT with `arg_ti`, with `arg_t`, and with a negative branch that
    is `arg_i` while `t > step_change` and `arg_null` otherwise, then returns

        scale_a * (pos_ti - pos_t) + scale_t * (pos_t - neg) + neg

    The CUDA cache is emptied after every forward pass.
    """
    def _predict(cond_kwargs):
        prediction, _ = self.parse_output(self.dit(latents, t=timestep, **cond_kwargs))
        torch.cuda.empty_cache()
        return prediction

    pos_ti = _predict(arg_ti)  # text + image
    pos_t = _predict(arg_t)    # text only

    # While t > step_change the negative branch includes the image; otherwise it does not.
    neg = _predict(arg_i if t > step_change else arg_null)

    # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config
    gen_cfg = self.config.generation
    return gen_cfg.scale_a * (pos_ti - pos_t) + gen_cfg.scale_t * (pos_t - neg) + neg
|
| 647 |
+
|
| 648 |
+
def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null):
    """Guided noise prediction for text + audio conditioning (no image).

    Runs the DiT with `arg_ta`, `arg_t` and `arg_null` (in that order,
    emptying the CUDA cache after each) and returns

        scale_a * (pos_ta - pos_t) + scale_t * (pos_t - neg) + neg
    """
    predictions = []
    for cond_kwargs in (arg_ta, arg_t, arg_null):
        prediction, _ = self.parse_output(self.dit(latents, t=timestep, **cond_kwargs))
        torch.cuda.empty_cache()
        predictions.append(prediction)
    pos_ta, pos_t, neg = predictions

    gen_cfg = self.config.generation
    return gen_cfg.scale_a * (pos_ta - pos_t) + gen_cfg.scale_t * (pos_t - neg) + neg
|
| 668 |
+
|
| 669 |
+
@torch.no_grad()
def inference(self,
              input_prompt,
              img_path,
              audio_path,
              size=(1280, 720),
              frame_num=81,
              shift=5.0,
              sample_solver='unipc',
              inference_mode='TIA',
              sampling_steps=50,
              n_prompt="",
              seed=-1,
              tea_cache_l1_thresh = 0.0,
              device = get_device(),
              ):
    """Generate one video from a prompt plus optional reference image(s) and audio.

    Args:
        input_prompt: Positive text prompt.
        img_path: Reference image path(s), or `None` for an all-zero reference latent.
        audio_path: Audio file path, or `None` for all-zero audio features.
        size: Output size, indexed as `(width, height)`.
        frame_num: Number of frames, reduced to the form `4k + 1`. `-1` uses the
            audio length, which is only known when audio features are extracted
            here (`generation.extract_audio_feat`).
        shift: Timestep shift passed to the scheduler.
        sample_solver: Only `'unipc'` is supported.
        inference_mode: `"TIA"`, `"TA"` or `"TI"`.
        sampling_steps: Number of denoising steps.
        n_prompt: Negative prompt; `""` uses `generation.sample_neg_prompt`.
        seed: RNG seed; negative values draw a random one.
        tea_cache_l1_thresh: A `TeaCache` is created per condition set only when
            this is not `None` and greater than 0.
        device: Device used for tensors created here.

    Returns:
        The first element of the list returned by `self.vae.decode`.

    Raises:
        ValueError: For an unsupported `sample_solver` or `inference_mode`, or
            when `frame_num == -1` but no audio length is available.
    """
    # Validate cheap options up front instead of failing after model work.
    if sample_solver != 'unipc':
        raise ValueError(f"Unsupported sample solver: {sample_solver}")
    if inference_mode not in ("TIA", "TA", "TI"):
        raise ValueError(f"Unsupported generation mode: {inference_mode}")

    print("inference started")

    # self.vae.model.to(device=device)
    if img_path is not None:
        latents_ref = self.load_image_latent_ref_id(img_path, size, device)
    else:
        latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
    # self.vae.model.to(device="cpu")
    num_ref_frames = latents_ref[0].shape[1]

    print("vae finished")

    # audio
    audio_length = None
    if audio_path is not None:
        if self.config.generation.get('extract_audio_feat', False):
            self.audio_processor.whisper.to(device=device)
            audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
            self.audio_processor.whisper.to(device='cpu')
        else:
            # Pre-extracted features live next to the audio file with a .pt
            # extension, whatever the audio extension is.
            audio_emb_path = os.path.splitext(audio_path)[0] + ".pt"
            audio_emb = torch.load(audio_emb_path).to(device=device)
            audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
            # Log message (Chinese): "using pre-extracted audio features".
            self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)

    if frame_num == -1:
        if audio_length is None:
            raise ValueError(
                "frame_num=-1 needs the audio length, which is only available when "
                "audio_path is given and generation.extract_audio_feat is enabled."
            )
        frame_num = audio_length

    if audio_path is None:
        audio_emb = torch.zeros(frame_num, 5, 1280).to(device)

    frame_num = 4 * ((frame_num - 1) // 4) + 1
    audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
    # One zero audio window per reference-latent frame appended to the video.
    zero_audio_pad = torch.zeros(num_ref_frames, *audio_emb.shape[1:]).to(audio_emb.device)
    audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
    audio_emb = [audio_emb.to(device)]
    audio_emb_neg = [torch.zeros_like(audio_emb[0])]

    # preprocess
    self.patch_size = self.config.dit.model.patch_size
    F = frame_num
    target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + num_ref_frames,
                    size[1] // self.vae_stride[1],
                    size[0] // self.vae_stride[2])

    # Round the token count up to a multiple of sp_size.
    seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                        (self.patch_size[1] * self.patch_size[2]) *
                        target_shape[1] / self.sp_size) * self.sp_size

    if n_prompt == "":
        n_prompt = self.config.generation.sample_neg_prompt
    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
    seed_g = torch.Generator(device=device)
    seed_g.manual_seed(seed)

    # self.text_encoder.model.to(device)
    context = self.text_encoder([input_prompt], device)
    context_null = self.text_encoder([n_prompt], device)
    # self.text_encoder.model.cpu()

    print("text encoder finished")

    noise = [
        torch.randn(
            target_shape[0],
            target_shape[1],  # - latents_ref[0].shape[1],
            target_shape[2],
            target_shape[3],
            dtype=torch.float32,
            device=device,
            generator=seed_g)
    ]

    @contextmanager
    def noop_no_sync():
        yield

    no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
    step_change = self.config.generation.step_change  # 980

    # evaluation mode
    # NOTE(review): the original indentation was not recoverable, so the exact
    # extent of this `with` block (here: through VAE decode) should be confirmed.
    with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():

        sample_scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000,
            shift=1,
            use_dynamic_shifting=False)
        sample_scheduler.set_timesteps(
            sampling_steps, device=device, shift=shift)
        timesteps = sample_scheduler.timesteps

        # sample videos
        latents = noise

        # The mask is 1 on the trailing reference frames and 0 on generated frames.
        msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=device)
        msk[:, :-num_ref_frames] = 0

        zero_vae = self.zero_vae[:, :(target_shape[1] - num_ref_frames)].to(
            device=device, dtype=latents_ref[0].dtype)
        y_c = torch.cat([
            zero_vae,
            latents_ref[0]
        ], dim=1)
        y_c = [torch.concat([msk, y_c])]

        y_null = self.zero_vae[:, :target_shape[1]].to(
            device=device, dtype=latents_ref[0].dtype)
        y_null = [torch.concat([msk, y_null])]

        tea_cache_model_id = "Wan2.1-T2V-14B"

        def _new_tea_cache():
            # A fresh cache per condition set, or None when caching is off.
            if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0:
                return TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)
            return None

        def _cond_args(audio, y, text):
            return {'seq_len': seq_len, 'audio': audio, 'y': y, 'context': text, "tea_cache": _new_tea_cache()}

        arg_null = _cond_args(audio_emb_neg, y_null, context_null)
        arg_t = _cond_args(audio_emb_neg, y_null, context)
        arg_i = _cond_args(audio_emb_neg, y_c, context_null)
        arg_ti = _cond_args(audio_emb_neg, y_c, context)
        arg_ta = _cond_args(audio_emb, y_null, context)
        arg_tia = _cond_args(audio_emb, y_c, context)

        torch.cuda.empty_cache()
        # self.dit.to(device=get_device())
        for t in tqdm(timesteps):
            timestep = torch.stack([t])

            if inference_mode == "TIA":
                noise_pred = self.forward_tia(latents, timestep, t, step_change,
                                              arg_tia, arg_ti, arg_i, arg_null)
            elif inference_mode == "TA":
                noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null)
            elif inference_mode == "TI":
                noise_pred = self.forward_ti(latents, timestep, t, step_change,
                                             arg_ti, arg_t, arg_i, arg_null)
            else:
                raise ValueError(f"Unsupported generation mode: {inference_mode}")

            temp_x0 = sample_scheduler.step(
                noise_pred.unsqueeze(0),
                t,
                latents[0].unsqueeze(0),
                return_dict=False,
                generator=seed_g)[0]
            latents = [temp_x0.squeeze(0)]

            del timestep
            torch.cuda.empty_cache()

        # Drop the trailing reference frames before decoding.
        x0 = [x0_[:, :-num_ref_frames] for x0_ in latents]

        # if offload_model:
        #     self.dit.cpu()

        print("dit finished")

        torch.cuda.empty_cache()
        # if get_local_rank() == 0:
        # self.vae.model.to(device=device)
        videos = self.vae.decode(x0)
        # self.vae.model.to(device="cpu")

        print("vae 2 finished")

    # Release large tensors before emptying the CUDA cache.
    del noise, latents, noise_pred
    del audio_emb, audio_emb_neg, latents_ref, context, context_null
    del x0, temp_x0
    del sample_scheduler
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()

    return videos[0]  # if get_local_rank() == 0 else None
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode="TIA", width=832, height=480, steps=50, frames=97, tea_cache_l1_thresh=0.0, seed=0):
    """Generate one video with ``self.inference`` and write it to disk.

    Args:
        prompt: Text prompt forwarded to ``self.inference``.
        ref_img_path: Reference image path(s) forwarded to ``self.inference``.
        audio_path: Audio file path, or None; also forwarded to ``save_sample``.
        output_dir: Directory the result is saved into.
        filename: Output file name without extension.
        inference_mode: Conditioning mode forwarded to ``self.inference``.
        width, height: Looked up in ``SIZE_CONFIGS`` as ``f"{width}*{height}"``;
            an unsupported pair raises KeyError.
        steps: Number of sampling steps.
        frames: Number of frames to generate.
        tea_cache_l1_thresh: TeaCache threshold forwarded to ``self.inference``.
        seed: Random seed forwarded to ``self.inference``.
    """
    resolution = SIZE_CONFIGS[f"{width}*{height}"]
    sampling_shift = self.config.diffusion.timesteps.sampling.shift

    video = self.inference(
        prompt,
        ref_img_path,
        audio_path,
        size=resolution,
        frame_num=frames,
        shift=sampling_shift,
        sample_solver='unipc',
        sampling_steps=steps,
        inference_mode=inference_mode,
        tea_cache_l1_thresh=tea_cache_l1_thresh,
        seed=seed,
    )

    # Release cached GPU memory before the (CPU-side) encoding of the result.
    torch.cuda.empty_cache()
    gc.collect()

    # Only sequence-parallel rank 0 writes the file.
    is_writer = get_sequence_parallel_rank() == 0
    if is_writer:
        pathname = self.save_sample(
            sample=video,
            audio_path=audio_path,
            output_dir=output_dir,
            filename=filename,
        )
        self.logger.info(f"Finished {filename}, saved to {pathname}.")

    del video, prompt
    torch.cuda.empty_cache()
    gc.collect()
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
    """Convert a generated video tensor to uint8 frames and write it as an .mp4.

    Args:
        sample: 4-D tensor laid out as (c, t, h, w) with values nominally in
            [-1, 1]; values outside that range are clipped.
        audio_path: Audio file handed to ``tensor_to_video`` together with the
            frames, or None to write a silent video with ``mediapy``.
        output_dir: Directory the file is written into.
        filename: Output file name without extension.

    Returns:
        The full path of the written file.

    Raises:
        ValueError: If ``sample`` is not 4-D. Only video tensors are supported.
    """
    # Reject unsupported ranks before doing any work: previously a non-4-D
    # tensor failed inside ``rearrange`` and never reached the intended error.
    if sample.ndim != 4:
        raise ValueError(
            f"save_sample expects a 4-D (c, t, h, w) video tensor, got {sample.ndim} dimensions."
        )

    gen_config = self.config.generation
    pathname = os.path.join(output_dir, filename + ".mp4")

    # ``clip`` returns a new tensor, so the in-place ops that follow do not
    # modify the caller's tensor. Maps [-1, 1] -> [0, 255].
    frames = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
    frames = rearrange(frames, "c t h w -> t h w c").numpy()

    if audio_path is not None:
        tensor_to_video(
            frames,
            pathname,
            audio_path,
            fps=gen_config.fps)
    else:
        mediapy.write_video(
            path=pathname,
            images=frames,
            fps=gen_config.fps,
        )
    return pathname
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
def prepare_positive_prompts(self):
    """Load the positive prompts named by ``config.generation.positive_prompt``.

    Only a path ending in ".json" is supported; it is loaded with
    ``prepare_json_dataset``.

    Returns:
        The loaded prompts, asserted to be a ``ListConfig``.

    Raises:
        NotImplementedError: If the configured value does not end with ".json".
    """
    prompt_source = self.config.generation.positive_prompt
    if not prompt_source.endswith(".json"):
        raise NotImplementedError

    loaded_prompts = prepare_json_dataset(prompt_source)
    assert isinstance(loaded_prompts, ListConfig)
    return loaded_prompts
|
| 932 |
+
|
| 933 |
+
class TeaCache:
    """Step-skipping cache that reuses the previous residual of a denoiser.

    ``check`` decides per step whether the model must be evaluated; ``store``
    records the residual (output minus input) of an evaluated step; ``update``
    applies the last stored residual to the hidden states of a skipped step.
    """

    # Polynomial coefficients (highest degree first, as ``np.poly1d`` expects)
    # used to rescale the relative L1 change of the modulated input, per model.
    _COEFFICIENTS = {
        "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
        "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
        "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
        "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
    }

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        """Create a cache for one sampling run.

        Args:
            num_inference_steps: Number of ``check`` calls after which the
                internal step counter wraps back to 0.
            rel_l1_thresh: Accumulated rescaled distance below which a step is
                skipped.
            model_id: Key selecting the rescaling polynomial.

        Raises:
            ValueError: If ``model_id`` has no coefficients.
        """
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        # Kept as an instance attribute for backward compatibility.
        self.coefficients_dict = dict(self._COEFFICIENTS)
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
        self.coefficients = self.coefficients_dict[model_id]
        # Built once here instead of on every ``check`` call.
        self._rescale_func = np.poly1d(self.coefficients)

    def check(self, dit, x, t_mod):
        """Advance one step and report whether the model call can be skipped.

        Args:
            dit: Unused; kept for interface compatibility.
            x: Hidden states entering the model; cloned when the step must be
                computed so ``store`` can later form the residual.
            t_mod: Modulated input whose step-to-step change drives the decision.

        Returns:
            True if the cached residual should be reused (skip the model),
            False if the model must be evaluated.
        """
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            # The first and last steps are always computed.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            rel_l1 = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).cpu().item()
            self.accumulated_rel_l1_distance += self._rescale_func(rel_l1)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp

        # Wrap the counter so the same instance can serve a following run.
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0

        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual of a computed step; no-op if none is pending."""
        if self.previous_hidden_states is None:
            return
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Return ``hidden_states`` plus the last stored residual.

        Raises:
            RuntimeError: If no residual has been stored yet.
        """
        if self.previous_residual is None:
            raise RuntimeError("TeaCache.update() called before any residual was stored.")
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
|
humo/generate_1_7B.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Inference codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import os
|
| 17 |
+
import gc
|
| 18 |
+
import random
|
| 19 |
+
import sys
|
| 20 |
+
import mediapy
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 24 |
+
from einops import rearrange
|
| 25 |
+
from omegaconf import OmegaConf
|
| 26 |
+
from PIL import Image, ImageOps
|
| 27 |
+
from torchvision.transforms import ToTensor
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
from torch.distributed.device_mesh import init_device_mesh
|
| 30 |
+
from torch.distributed.fsdp import (
|
| 31 |
+
BackwardPrefetch,
|
| 32 |
+
FullyShardedDataParallel,
|
| 33 |
+
MixedPrecision,
|
| 34 |
+
ShardingStrategy,
|
| 35 |
+
)
|
| 36 |
+
from common.distributed import (
|
| 37 |
+
get_device,
|
| 38 |
+
get_global_rank,
|
| 39 |
+
get_local_rank,
|
| 40 |
+
meta_param_init_fn,
|
| 41 |
+
meta_non_persistent_buffer_init_fn,
|
| 42 |
+
init_torch,
|
| 43 |
+
)
|
| 44 |
+
from common.distributed.advanced import (
|
| 45 |
+
init_unified_parallel,
|
| 46 |
+
get_unified_parallel_world_size,
|
| 47 |
+
get_sequence_parallel_rank,
|
| 48 |
+
init_model_shard_cpu_group,
|
| 49 |
+
)
|
| 50 |
+
from common.logger import get_logger
|
| 51 |
+
from common.config import create_object
|
| 52 |
+
from common.distributed import get_device, get_global_rank
|
| 53 |
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
| 54 |
+
from humo.models.wan_modules.t5 import T5EncoderModel
|
| 55 |
+
from humo.models.wan_modules.vae import WanVAE
|
| 56 |
+
from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
|
| 57 |
+
from contextlib import contextmanager
|
| 58 |
+
import torch.cuda.amp as amp
|
| 59 |
+
from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 60 |
+
from humo.utils.audio_processor_whisper import AudioProcessor
|
| 61 |
+
from humo.utils.wav2vec import linear_interpolation_fps
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Module-level preprocessing pipeline: converts an image to a tensor and then
# normalizes it with mean 0.5 and std 0.5.
image_transform = Compose([
    ToTensor(),
    Normalize(mean=0.5, std=0.5),
])
|
| 68 |
+
|
| 69 |
+
# Supported output resolutions, keyed "<width>*<height>" and mapping to the
# matching (width, height) tuple.
SIZE_CONFIGS = {
    f"{width}*{height}": (width, height)
    for width, height in (
        (720, 1280),
        (1280, 720),
        (480, 832),
        (832, 480),
        (1024, 1024),
    )
}
|
| 76 |
+
|
| 77 |
+
def clever_format(nums, format="%.2f"):
    """Format number(s) with a T/G/M/K magnitude suffix ("B" when below 1e3).

    Args:
        nums: A single number or an iterable of numbers.
        format: printf-style format applied to the scaled value.

    Returns:
        A string for a single value, otherwise a tuple of strings.
        Thresholds are strict, so exactly 1000 renders as "1000.00B".
    """
    from typing import Iterable

    magnitudes = ((1e12, "T"), (1e9, "G"), (1e6, "M"), (1e3, "K"))

    def render(value):
        for threshold, suffix in magnitudes:
            if value > threshold:
                return format % (value / threshold) + suffix
        return format % value + "B"

    values = nums if isinstance(nums, Iterable) else [nums]
    rendered = [render(value) for value in values]

    if len(rendered) == 1:
        return rendered[0]
    return tuple(rendered)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Generator():
|
| 100 |
+
def __init__(self, config: DictConfig) -> None:
    """Store a read-only copy of ``config``, create the logger and build all models."""
    # Copy first so marking it read-only does not affect the caller's object.
    self.config = config.copy()
    OmegaConf.set_readonly(self.config, True)
    self.logger = get_logger(self.__class__.__name__)
    # Needs self.config and self.logger, so it must run last.
    self.configure_models()
|
| 105 |
+
|
| 106 |
+
# init_torch(cudnn_benchmark=False)
|
| 107 |
+
|
| 108 |
+
def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
|
| 109 |
+
device_mesh = None
|
| 110 |
+
fsdp_strategy = ShardingStrategy[sharding_strategy]
|
| 111 |
+
if (
|
| 112 |
+
fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
|
| 113 |
+
and device_mesh_config is not None
|
| 114 |
+
):
|
| 115 |
+
device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
|
| 116 |
+
return device_mesh, fsdp_strategy
|
| 117 |
+
|
| 118 |
+
def configure_models(self) -> None:
    """Build every model on CPU, then move the DiT and text encoder to the device."""
    self.configure_dit_model(device="cpu")
    self.configure_vae_model()
    # The audio feature extractor is only needed when features are computed
    # at inference time.
    if self.config.generation.get('extract_audio_feat', False):
        self.configure_wav2vec(device="cpu")
    self.configure_text_model(device="cpu")

    # Initialize fsdp.
    # NOTE(review): the two methods below only call .to(get_device()); no FSDP
    # wrapping is visible in this file.
    self.configure_dit_fsdp_model()
    self.configure_text_fsdp_model()
|
| 128 |
+
|
| 129 |
+
def configure_dit_model(self, device=get_device()):
    """Create the DiT on the meta device and load its checkpoint.

    ``config.dit.checkpoint_dir`` is either a single ``.pth`` file or a
    directory holding ``humo.safetensors.index.json`` plus the shard files it
    lists. In both cases the number of missing/unexpected keys is logged.

    Args:
        device: Device the checkpoint tensors are loaded onto. The model is
            additionally moved to ``get_device()`` when ``device`` is that
            device or the string "cuda".
    """
    init_unified_parallel(self.config.dit.sp_size)
    self.sp_size = get_unified_parallel_world_size()

    # Create dit model. Building under the "meta" device allocates no parameter
    # storage; load_state_dict(assign=True) below attaches the loaded tensors.
    init_device = "meta"
    with torch.device(init_device):
        self.dit = create_object(self.config.dit.model)
    self.logger.info(f"Load DiT model on {init_device}.")
    self.dit.eval().requires_grad_(False)

    def load_custom_sharded_weights(model_dir, base_name, device=device):
        """Merge all shards listed in ``<base_name>.safetensors.index.json``."""
        # Imported lazily so ``.pth`` checkpoints do not require safetensors.
        from safetensors.torch import load_file
        import json

        index_path = os.path.join(model_dir, f"{base_name}.safetensors.index.json")
        with open(index_path, "r") as f:
            weight_map = json.load(f)["weight_map"]

        state_dict = {}
        # Several keys map to the same shard; sort for a deterministic order.
        for shard_file in sorted(set(weight_map.values())):
            shard_state = load_file(os.path.join(model_dir, shard_file))
            state_dict.update({k: v.to(device) for k, v in shard_state.items()})
        return state_dict

    # Load dit checkpoint.
    path = self.config.dit.checkpoint_dir
    if path.endswith(".pth"):
        state = torch.load(path, map_location=device, mmap=True)
    else:
        state = load_custom_sharded_weights(path, 'humo', device)

    # strict=False tolerates key mismatches, so report them for both formats
    # (the sharded path used to discard this information).
    missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
    self.logger.info(
        f"dit loaded from {path}. "
        f"Missing keys: {len(missing_keys)}, "
        f"Unexpected keys: {len(unexpected_keys)}"
    )

    # NOTE(review): presumably materializes buffers that are not part of the
    # state dict and are therefore still on "meta" — confirm in common.distributed.
    self.dit = meta_non_persistent_buffer_init_fn(self.dit)
    if device in [get_device(), "cuda"]:
        self.dit.to(get_device())

    # Print model size.
    params = sum(p.numel() for p in self.dit.parameters())
    self.logger.info(
        f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
    )
|
| 179 |
+
|
| 180 |
+
def configure_vae_model(self, device=get_device()):
    """Create the Wan VAE and load the precomputed "zero VAE" tensor.

    The zero-VAE file is selected by ``config.generation.height``: 480 uses
    ``config.dit.zero_vae_path``, 720 uses ``config.dit.zero_vae_720p_path``.

    Raises:
        ValueError: For any other configured height.
    """
    self.vae_stride = self.config.vae.vae_stride
    self.vae = WanVAE(vae_pth=self.config.vae.checkpoint, device=device)

    configured_height = self.config.generation.height
    zero_vae_sources = ((480, "zero_vae_path"), (720, "zero_vae_720p_path"))
    for supported_height, path_key in zero_vae_sources:
        if configured_height == supported_height:
            # Only the matching config entry is read, as before.
            self.zero_vae = torch.load(getattr(self.config.dit, path_key))
            break
    else:
        raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
|
| 192 |
+
|
| 193 |
+
def configure_wav2vec(self, device=get_device()):
    """Create ``self.audio_processor`` from the audio section of the config.

    Args:
        device: Device passed to ``AudioProcessor``.
    """
    audio_config = self.config.audio
    separator_model_file = audio_config.vocal_separator
    feature_model_path = audio_config.wav2vec_model
    vocals_dir = os.path.join(self.config.generation.output.dir, "vocals")

    # Positional arguments are kept in the processor's expected order.
    # NOTE(review): 16000 and 25 look like sample rate and fps — confirm
    # against AudioProcessor's signature.
    self.audio_processor = AudioProcessor(
        16000,
        25,
        feature_model_path,
        "all",
        separator_model_file,
        None,  # not separate
        vocals_dir,
        device=device,
    )
|
| 207 |
+
|
| 208 |
+
def configure_text_model(self, device=get_device()):
    """Create the bfloat16 T5 text encoder as ``self.text_encoder``.

    Args:
        device: Device passed to ``T5EncoderModel``.
    """
    encoder_kwargs = dict(
        text_len=self.config.dit.model.text_len,
        dtype=torch.bfloat16,
        device=device,
        checkpoint_path=self.config.text.t5_checkpoint,
        tokenizer_path=self.config.text.t5_tokenizer,
    )
    self.text_encoder = T5EncoderModel(**encoder_kwargs)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def configure_dit_fsdp_model(self):
    """Move the DiT to the device returned by ``get_device()``.

    Despite the name, no FSDP wrapping is applied here.
    """
    target_device = get_device()
    self.dit.to(target_device)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def configure_text_fsdp_model(self):
    """Move the text encoder to the device returned by ``get_device()``.

    Despite the name, no FSDP wrapping is applied here.
    """
    target_device = get_device()
    self.text_encoder.to(target_device)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def load_image_latent_ref_id(self, path: str, size, device):
    """Encode one or more reference images into VAE latents.

    Each image is converted to RGB, resized to fit inside the target size
    while keeping its aspect ratio, centered on a white canvas of exactly the
    target size, normalized with mean/std 0.5 and encoded with ``self.vae``.

    Args:
        path: A single image path, or a sequence of image paths.
        size: ``(width, height)`` of the target canvas.
        device: Device passed to ``self.vae.encode``.

    Returns:
        For a single image, whatever ``self.vae.encode`` returns. For several
        images, a one-element list holding their first latents concatenated
        along dim 1.

    Raises:
        ValueError: If ``path`` is an empty sequence.
    """
    # Load size.
    h, w = size[1], size[0]

    # Built once and reused for every image.
    transform = Compose(
        [
            ToTensor(),
            Normalize(0.5, 0.5),
        ]
    )

    def encode_reference(image_path):
        """Letterbox one image to (w, h) and return its VAE encoding."""
        with Image.open(image_path) as img:
            img = img.convert("RGB")

            # Calculate the required size to keep aspect ratio and fill the rest with padding.
            img_ratio = img.width / img.height
            target_ratio = w / h

            # max(1, ...) keeps extreme aspect ratios from producing a
            # 0-pixel dimension, which resize() rejects.
            if img_ratio > target_ratio:  # Image is wider than target
                new_width = w
                new_height = max(1, int(new_width / img_ratio))
            else:  # Image is taller than target
                new_height = h
                new_width = max(1, int(new_height * img_ratio))

            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Pad symmetrically; any odd leftover pixel goes to the right/bottom.
            delta_w = w - img.size[0]
            delta_h = h - img.size[1]
            padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
            new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))

        # Transform to tensor and normalize, then add a length-1 dim at index 1
        # before encoding.
        new_img = transform(new_img)
        return self.vae.encode([new_img.unsqueeze(1)], device)

    if isinstance(path, str):
        return encode_reference(path)

    image_paths = list(path)
    if not image_paths:
        raise ValueError("load_image_latent_ref_id requires at least one reference image path.")
    if len(image_paths) == 1:
        return encode_reference(image_paths[0])

    ref_vae_latents = [encode_reference(image_path)[0] for image_path in image_paths]
    return [torch.cat(ref_vae_latents, dim=1)]
|
| 312 |
+
|
| 313 |
+
def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
    """Slice per-frame audio features into one window per latent frame.

    ``1 + (frame_num - 1) // 4`` windows are produced. The first window takes
    frames ``frame0_idx - 2 .. frame0_idx + 2`` and is prefixed with three zero
    frames; every later window ``i`` takes frames
    ``frame0_idx + 1 + 4*(i-1) - audio_shift .. frame0_idx + 4*i + audio_shift``.
    Indices outside ``audio_emb`` are filled with zeros.

    Args:
        audio_emb: 3-D tensor indexed by frame along dim 0.
        frame_num: Number of video frames.
        frame0_idx: Index in ``audio_emb`` of the first video frame.
        audio_shift: Extra context frames on each side of later windows.

    Returns:
        ``(windows, next_idx)``: the stacked windows and the end index of the
        last window minus ``audio_shift``.
    """
    total_frames = audio_emb.shape[0]
    feat_shape = (audio_emb.shape[1], audio_emb.shape[2])
    silent_frame = torch.zeros(feat_shape, dtype=audio_emb.dtype, device=audio_emb.device)
    leading_pad = torch.zeros((3, *feat_shape), dtype=audio_emb.dtype, device=audio_emb.device)

    def gather(start, stop):
        # Collect frames [start, stop), substituting zeros when out of range.
        picked = []
        for idx in range(start, stop):
            in_range = 0 <= idx < total_frames
            picked.append(audio_emb[idx] if in_range else silent_frame)
        return torch.stack(picked, dim=0)

    windows = []
    num_windows = 1 + (frame_num - 1) // 4
    for win_idx in range(num_windows):
        if win_idx == 0:
            st = frame0_idx - 2
            ed = frame0_idx + 3
            window = torch.cat((leading_pad, gather(st, ed)), dim=0)
        else:
            st = frame0_idx + 1 + 4 * (win_idx - 1) - audio_shift
            ed = frame0_idx + 1 + 4 * win_idx + audio_shift
            window = gather(st, ed)
        windows.append(window)

    return torch.stack(windows, dim=0), ed - audio_shift
|
| 338 |
+
|
| 339 |
+
def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
    """Reduce raw audio-encoder features to the layout used for conditioning.

    For "wav2vec" the input is returned unchanged. For "whisper", indices
    0-31 along dim 2 are averaged in four groups of eight and index 32 is
    taken as is; each of the five results is passed through
    ``linear_interpolation_fps(..., 50, 25)``, they are stacked along dim 2,
    and the first element along dim 0 is returned.

    Raises:
        ValueError: For any other ``wav_enc_type``.
    """
    if wav_enc_type == "wav2vec":
        return audio_emb
    if wav_enc_type != "whisper":
        raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")

    # NOTE(review): 50 -> 25 presumably resamples the feature rate to the video
    # fps — confirm in humo.utils.wav2vec.
    resampled = [
        linear_interpolation_fps(audio_emb[:, :, lo:lo + 8].mean(dim=2), 50, 25)
        for lo in (0, 8, 16, 24)
    ]
    resampled.append(linear_interpolation_fps(audio_emb[:, :, 32], 50, 25))

    return torch.stack(resampled, dim=2)[0]
|
| 353 |
+
|
| 354 |
+
def forward_tia(self, latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
    """Combine four DiT passes into one text+image+audio guided prediction.

    The trailing frames (dim 1) of every latent are replaced by a reference
    latent before each pass:

    * ``neg``     - ``latents_ref_neg`` with ``arg_null``
    * ``pos_t``   - ``latents_ref_neg`` with ``arg_t``
    * ``pos_ta``  - ``latents_ref_neg`` with ``arg_ta``
    * ``pos_tia`` - ``latents_ref`` with ``arg_ta``

    Returns:
        ``scale_i*(pos_tia-pos_ta) + scale_a*(pos_ta-pos_t)
        + scale_t*(pos_t-neg) + neg`` with scales from ``config.generation``.
    """
    def with_reference(references):
        # Swap the last ref.shape[1] frames of each latent for the reference.
        return [
            torch.cat([latent[:, :-ref.shape[1]], ref], dim=1)
            for latent, ref in zip(latents, references)
        ]

    def predict(references, conditioning):
        return self.dit(with_reference(references), t=timestep, **conditioning)[0]

    neg = predict(latents_ref_neg, arg_null)
    pos_t = predict(latents_ref_neg, arg_t)
    pos_ta = predict(latents_ref_neg, arg_ta)
    pos_tia = predict(latents_ref, arg_ta)

    image_term = self.config.generation.scale_i * (pos_tia - pos_ta)
    audio_term = self.config.generation.scale_a * (pos_ta - pos_t)
    text_term = self.config.generation.scale_t * (pos_t - neg)
    # Summed left to right in the same order as before.
    return image_term + audio_term + text_term + neg
|
| 375 |
+
|
| 376 |
+
def forward_ta(self, latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
    """Combine three DiT passes into one text+audio guided prediction.

    Every pass replaces the trailing frames (dim 1) of each latent with the
    matching entry of ``latents_ref_neg`` and differs only in conditioning:
    ``arg_null`` (``neg``), ``arg_t`` (``pos_t``) and ``arg_ta`` (``pos_ta``).

    Returns:
        ``scale_a*(pos_ta-pos_t) + scale_t*(pos_t-neg) + neg`` with scales
        from ``config.generation``.
    """
    def predict(conditioning):
        model_input = [
            torch.cat([latent[:, :-ref.shape[1]], ref], dim=1)
            for latent, ref in zip(latents, latents_ref_neg)
        ]
        return self.dit(model_input, t=timestep, **conditioning)[0]

    neg = predict(arg_null)
    pos_t = predict(arg_t)
    pos_ta = predict(arg_ta)

    audio_term = self.config.generation.scale_a * (pos_ta - pos_t)
    text_term = self.config.generation.scale_t * (pos_t - neg)
    # Summed left to right in the same order as before.
    return audio_term + text_term + neg
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@torch.no_grad()
def inference(self,
              input_prompt,
              img_path,
              audio_path,
              size=(1280, 720),
              frame_num=81,
              shift=5.0,
              sample_solver='unipc',
              sampling_steps=50,
              n_prompt="",
              seed=-1,
              offload_model=True,
              device=get_device(),
              ):
    """Generate one video from a prompt, optional reference image(s) and audio.

    Args:
        input_prompt: Positive text prompt.
        img_path: Reference image path(s), or None for an all-zero reference.
        audio_path: Audio file path, or None for all-zero audio features.
        size: ``(width, height)`` of the output.
        frame_num: Number of frames, rounded down to the form ``4k + 1``;
            ``-1`` takes the length from the audio.
        shift: Timestep shift passed to the scheduler.
        sample_solver: Only ``'unipc'`` is supported.
        sampling_steps: Number of denoising steps.
        n_prompt: Negative prompt; "" uses ``config.generation.sample_neg_prompt``.
        seed: RNG seed; a negative value picks a random one.
        offload_model: Currently ignored; models are always moved back to CPU
            after use.
        device: Device used for computation.

    Returns:
        The first decoded video returned by ``self.vae.decode``.

    Raises:
        NotImplementedError: If ``sample_solver`` is not ``'unipc'``.
        ValueError: If ``config.generation.mode`` is neither "TIA" nor "TA",
            or if ``frame_num == -1`` without ``audio_path``.
    """
    # Fail fast on unsupported options instead of after the expensive setup
    # (an unknown solver previously surfaced as a NameError).
    if sample_solver != 'unipc':
        raise NotImplementedError(f"Unsupported solver: {sample_solver}")
    generation_mode = self.config.generation.mode
    if generation_mode not in ("TIA", "TA"):
        raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
    if frame_num == -1 and audio_path is None:
        raise ValueError("frame_num=-1 takes the length from the audio, so audio_path is required.")

    # Reference image latents; the VAE is only on the device while encoding.
    self.vae.model.to(device=device)
    if img_path is not None:
        latents_ref = self.load_image_latent_ref_id(img_path, size, device)
    else:
        latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]

    self.vae.model.to(device="cpu")
    latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]

    # audio
    if audio_path is not None:
        # Same default as configure_models, which only builds the audio
        # processor when this flag is set.
        if self.config.generation.get('extract_audio_feat', False):
            self.audio_processor.whisper.to(device=device)
            audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
            self.audio_processor.whisper.to(device='cpu')
        else:
            # Use pre-extracted features stored next to the audio file.
            audio_emb_path = audio_path.replace(".wav", ".pt")
            audio_emb = torch.load(audio_emb_path).to(device=device)
            audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
            # Frames are indexed along dim 0 (see get_audio_emb_window); this
            # was previously undefined on this path when frame_num == -1.
            audio_length = audio_emb.shape[0]
            self.logger.info("使用预先提取好的音频特征: %s", audio_emb_path)
    else:
        audio_length = frame_num
        audio_emb = torch.zeros(frame_num, 5, 1280).to(device)

    frame_num = frame_num if frame_num != -1 else audio_length
    frame_num = 4 * ((frame_num - 1) // 4) + 1
    audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
    # The reference frames appended to the latents get zero audio.
    zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
    audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
    audio_emb = [audio_emb.to(device)]
    audio_emb_neg = [torch.zeros_like(audio_emb[0])]

    # preprocess
    self.patch_size = self.config.dit.model.patch_size
    F = frame_num
    # Latent shape: video frames plus the appended reference frames along dim 1.
    target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
                    size[1] // self.vae_stride[1],
                    size[0] // self.vae_stride[2])

    # Token count, rounded up to a multiple of the sequence-parallel size.
    seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                        (self.patch_size[1] * self.patch_size[2]) *
                        target_shape[1] / self.sp_size) * self.sp_size

    if n_prompt == "":
        n_prompt = self.config.generation.sample_neg_prompt
    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
    seed_g = torch.Generator(device=device)
    seed_g.manual_seed(seed)

    # Text embeddings; the encoder is only on the device while encoding.
    self.text_encoder.model.to(device)
    context = self.text_encoder([input_prompt], device)
    context_null = self.text_encoder([n_prompt], device)
    self.text_encoder.model.cpu()

    noise = [
        torch.randn(
            target_shape[0],
            target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=torch.float32,
            device=device,
            generator=seed_g)
    ]

    @contextmanager
    def noop_no_sync():
        yield

    # Use the model's no_sync() if it has one, otherwise a no-op context.
    no_sync = getattr(self.dit, 'no_sync', noop_no_sync)

    # evaluation mode
    with amp.autocast(dtype=torch.bfloat16), torch.no_grad(), no_sync():

        sample_scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000,
            shift=1,
            use_dynamic_shifting=False)
        sample_scheduler.set_timesteps(
            sampling_steps, device=device, shift=shift)
        timesteps = sample_scheduler.timesteps

        # sample videos
        latents = noise

        # The reference image is spliced into the latents explicitly by
        # forward_tia / forward_ta; it is not passed through these dicts.
        arg_ta = {'context': context, 'seq_len': seq_len, 'audio': audio_emb}
        arg_t = {'context': context, 'seq_len': seq_len, 'audio': audio_emb_neg}
        arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg}

        torch.cuda.empty_cache()
        self.dit.to(device=get_device())
        for t in tqdm(timesteps):
            timestep = torch.stack([t])

            if generation_mode == "TIA":
                noise_pred = self.forward_tia(latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)
            else:  # "TA", validated above
                noise_pred = self.forward_ta(latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null)

            temp_x0 = sample_scheduler.step(
                noise_pred.unsqueeze(0),
                t,
                latents[0].unsqueeze(0),
                return_dict=False,
                generator=seed_g)[0]
            latents = [temp_x0.squeeze(0)]

            del timestep
            torch.cuda.empty_cache()

        # Drop the trailing reference frames before decoding.
        x0 = latents
        x0 = [x0_[:, :-latents_ref[0].shape[1]] for x0_ in x0]

        # Swap the DiT out and the VAE in for decoding.
        self.dit.cpu()
        torch.cuda.empty_cache()
        self.vae.model.to(device=device)
        videos = self.vae.decode(x0)
        self.vae.model.to(device="cpu")

    # Drop references to large tensors so empty_cache() can release them.
    del noise, latents, noise_pred
    del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
    del x0, temp_x0
    del sample_scheduler
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()

    return videos[0]
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, width=832, height=480, steps=50, frames=97, seed=0):
    """Generate one video for ``prompt`` and save it under ``output_dir``.

    Runs ``self.inference`` with the UniPC solver, then writes the result with
    ``self.save_sample`` on sequence-parallel rank 0 only.

    Args:
        prompt: Text prompt forwarded to ``self.inference``.
        ref_img_path: Reference image path forwarded to ``self.inference``.
        audio_path: Audio path forwarded to ``self.inference`` and ``self.save_sample``.
        output_dir: Directory passed to ``self.save_sample``.
        filename: Output file name (extension is added by ``self.save_sample``).
        width: Together with ``height`` forms the ``"{width}*{height}"`` key
            looked up in ``SIZE_CONFIGS``.
        height: See ``width``.
        steps: Passed as ``sampling_steps``.
        frames: Passed as ``frame_num``.
        seed: Passed as ``seed``.

    Raises:
        KeyError: If ``"{width}*{height}"`` is not a key of ``SIZE_CONFIGS``.
    """

    def _free_memory():
        # Release cached CUDA blocks, then run the Python garbage collector.
        torch.cuda.empty_cache()
        gc.collect()

    print(f'ref_img_path:{ref_img_path}')

    target_size = SIZE_CONFIGS[f"{width}*{height}"]
    sampling_shift = self.config.diffusion.timesteps.sampling.shift
    video = self.inference(
        prompt,
        ref_img_path,
        audio_path,
        size=target_size,
        frame_num=frames,
        shift=sampling_shift,
        sample_solver='unipc',
        sampling_steps=steps,
        seed=seed,
        offload_model=False,
    )
    _free_memory()

    # Save samples (only one rank writes the file).
    is_writer_rank = get_sequence_parallel_rank() == 0
    if is_writer_rank:
        pathname = self.save_sample(
            sample=video,
            audio_path=audio_path,
            output_dir=output_dir,
            filename=filename,
        )
        self.logger.info(f"Finished {filename}, saved to {pathname}.")

    del video, prompt
    _free_memory()
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
    """Write a generated video tensor to ``output_dir`` as an ``.mp4`` file.

    Args:
        sample: Video tensor laid out as ``(c, t, h, w)``. Values are clipped
            to ``[-1, 1]`` and rescaled to ``uint8`` in ``[0, 255]``.
        audio_path: Audio file passed to ``tensor_to_video``; when ``None`` the
            frames are written with ``mediapy.write_video`` instead.
        output_dir: Directory the file is written into.
        filename: File name without extension.

    Returns:
        The full path of the written file.

    Raises:
        ValueError: If ``sample`` is not 4-dimensional.
    """
    # Only 4-D video tensors can be rearranged and written below. Rejecting
    # other ranks first gives a clear error instead of a failure inside
    # ``rearrange``.
    if sample.ndim != 4:
        raise ValueError(
            f"save_sample expects a 4-D (c, t, h, w) video tensor, "
            f"got {sample.ndim} dimension(s)."
        )

    gen_config = self.config.generation

    # Prepare file path.
    pathname = os.path.join(output_dir, filename + ".mp4")

    # Convert sample: [-1, 1] float -> [0, 255] uint8 on the CPU, channels last.
    # ``clip`` returns a new tensor, so the in-place ops do not touch the input.
    frames = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
    frames = rearrange(frames, "c t h w -> t h w c").numpy()

    # Save file.
    if audio_path is not None:
        tensor_to_video(
            frames,
            pathname,
            audio_path,
            fps=gen_config.fps)
    else:
        mediapy.write_video(
            path=pathname,
            images=frames,
            fps=gen_config.fps,
        )
    return pathname
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def prepare_positive_prompts(self):
    """Load the positive prompts referenced by the generation config.

    ``self.config.generation.positive_prompt`` must be a path ending in
    ``.json``; it is loaded with ``prepare_json_dataset``.

    Returns:
        The loaded prompts (asserted to be a ``ListConfig``).

    Raises:
        NotImplementedError: If the configured value does not end in ``.json``.
    """
    prompt_source = self.config.generation.positive_prompt
    if not prompt_source.endswith(".json"):
        raise NotImplementedError

    loaded_prompts = prepare_json_dataset(prompt_source)
    assert isinstance(loaded_prompts, ListConfig)
    return loaded_prompts
|
humo/models/audio/audio_proj.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
from torch import nn
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
class WanRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Per-channel scale, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # Normalise in float32, cast back to the input dtype, then scale.
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed over the last dimension.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DummyAdapterLayer(nn.Module):
    """Wrapper that stores a module as ``self.layer`` and forwards every call to it."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        return self.layer(*args, **kwargs)


class AudioProjModel(nn.Module):
    """Project windowed audio features to a fixed number of context tokens per frame.

    Each frame's ``(seq_len, blocks, channels)`` window is flattened, passed
    through a three-layer MLP with ReLU activations, reshaped to
    ``(context_tokens, output_dim)`` and layer-normalised.

    Args:
        seq_len: Window length per frame.
        blocks: Number of feature blocks per window position.
        channels: Channels per block.
        intermediate_dim: Hidden width of the MLP.
        output_dim: Channel size of each emitted context token.
        context_tokens: Number of context tokens emitted per frame.
    """

    def __init__(
        self,
        seq_len=5,
        blocks=13,  # add a new parameter blocks
        channels=768,  # add a new parameter channels
        intermediate_dim=512,
        output_dim=1536,
        context_tokens=16,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        # Flattened size of one frame's window.
        self.input_dim = seq_len * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # define multiple linear layers
        self.audio_proj_glob_1 = DummyAdapterLayer(nn.Linear(self.input_dim, intermediate_dim))
        self.audio_proj_glob_2 = DummyAdapterLayer(nn.Linear(intermediate_dim, intermediate_dim))
        self.audio_proj_glob_3 = DummyAdapterLayer(nn.Linear(intermediate_dim, context_tokens * output_dim))

        self.audio_proj_glob_norm = DummyAdapterLayer(nn.LayerNorm(output_dim))

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-uniform initialise every ``nn.Linear`` weight and zero its bias."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(self, audio_embeds):
        """Map ``[bz, f, w, b, c]`` audio features to ``[bz, f, context_tokens, output_dim]``.

        Args:
            audio_embeds: 5-D tensor ``(batch, frames, window, blocks, channels)``.

        Returns:
            Tensor of shape ``(batch, frames, self.context_tokens, self.output_dim)``.
        """
        batch, video_length, window_size, blocks, channels = audio_embeds.shape
        frames_total = batch * video_length

        # ``reshape`` (not ``view``) so non-contiguous inputs, e.g. the result
        # of a permute, are accepted instead of raising.
        flat = audio_embeds.reshape(frames_total, window_size * blocks * channels)

        hidden = torch.relu(self.audio_proj_glob_1(flat))
        hidden = torch.relu(self.audio_proj_glob_2(hidden))

        context_tokens = self.audio_proj_glob_3(hidden).reshape(
            frames_total, self.context_tokens, self.output_dim
        )
        context_tokens = self.audio_proj_glob_norm(context_tokens)

        # Split the merged (batch * frames) axis back into its two parts.
        return context_tokens.reshape(
            batch, video_length, self.context_tokens, self.output_dim
        )
|
humo/models/distributed/__init__.py
ADDED
|
File without changes
|
humo/models/distributed/dit_ulysses_sequence_parallel.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from common.distributed import get_device
|
| 15 |
+
|
| 16 |
+
from common.distributed.advanced import (
|
| 17 |
+
get_unified_parallel_world_size,
|
| 18 |
+
get_unified_parallel_group,
|
| 19 |
+
pad_tensor,
|
| 20 |
+
Slice,
|
| 21 |
+
gather_outputs,
|
| 22 |
+
gather_seq_scatter_heads_qkv,
|
| 23 |
+
gather_seq_scatter_double_head,
|
| 24 |
+
gather_heads_scatter_seq,
|
| 25 |
+
unpad_tensor
|
| 26 |
+
)
|
| 27 |
+
from humo.models.wan_modules.attention import flash_attention
|
| 28 |
+
from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def ulysses_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    audio=None,
    y=None
):
    """
    DiT forward pass with Ulysses sequence parallelism.

    The video tokens (and the audio tokens when ``self.insert_audio`` is set)
    are zero-padded, sliced across the unified parallel group, run through
    ``self.blocks`` and ``self.head``, then gathered and unpatchified.

    x:          A list of videos each with shape [C, T, H, W].
    t:          [B].
    context:    A list of text embeddings each with shape [L, C].
    seq_len:    Length every per-sample token sequence is zero-padded to.
    audio:      Per-sample audio features fed to ``self.audio_proj``; only
                read when ``self.insert_audio`` is set.
    y:          Optional list of tensors concatenated onto the matching entry
                of ``x`` along dim 0; required when ``self.model_type == 'i2v'``.

    Returns a list of float32 tensors produced by ``self.unpatchify``.
    """

    def _zero_pad_tokens(tokens, target_len):
        # Right-pad a [1, n, c] token tensor with zeros along dim 1.
        filler = tokens.new_zeros(1, target_len - tokens.size(1), tokens.size(2))
        return torch.cat([tokens, filler], dim=1)

    if self.model_type == 'i2v':
        # assert clip_fea is not None and y is not None
        assert y is not None

    # params: keep ``self.freqs`` on the same device as the patch-embedding weights.
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([video, cond], dim=0) for video, cond in zip(x, y)]

    # embeddings: patch-embed each sample, record its grid, flatten to tokens.
    patches = [self.patch_embedding(video.unsqueeze(0)) for video in x]
    grid_sizes = torch.stack(
        [torch.tensor(patch.shape[2:], dtype=torch.long) for patch in patches])
    token_list = [patch.flatten(2).transpose(1, 2) for patch in patches]
    seq_lens = torch.tensor(
        [tokens.size(1) for tokens in token_list], dtype=torch.long, device=device)

    assert seq_lens.max() <= seq_len
    x = torch.cat([_zero_pad_tokens(tokens, seq_len) for tokens in token_list])

    # time embeddings (forced to float32)
    with torch.amp.autocast('cuda', dtype=torch.float32):
        sinusoid = sinusoidal_embedding_1d(self.freq_dim, t).float()
        e = self.time_embedding(sinusoid).float()
        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context: zero-pad every text embedding to ``self.text_len`` rows.
    context_lens = None
    padded_text = [
        torch.cat([emb, emb.new_zeros(self.text_len - emb.size(0), emb.size(1))])
        for emb in context
    ]
    context = self.text_embedding(torch.stack(padded_text))

    if self.insert_audio:
        projected = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]

        longest = max([au.shape[2] for au in projected])
        audio_seq_len = torch.tensor(longest * projected[0].shape[3], device=get_device())
        audio_tokens = [au.flatten(2).transpose(1, 2) for au in projected]  # [1, t*32, 1536]
        audio_seq_lens = torch.tensor(
            [au.size(1) for au in audio_tokens], dtype=torch.long, device=device)
        audio = torch.cat([_zero_pad_tokens(au, audio_seq_len) for au in audio_tokens])
    else:
        audio = None
        audio_seq_len = None
        audio_seq_lens = None

    # ulysses support: pad so the sequences can be sliced across ranks.
    sp_world = get_unified_parallel_world_size()
    group = get_unified_parallel_group()
    remainder = seq_len % sp_world
    if remainder:
        x = pad_tensor(x, dim=1, padding_size=sp_world - remainder)

    # NOTE(review): the source's indentation was lost; this audio-padding step
    # is taken to be independent of the video padding above -- confirm.
    if self.insert_audio:
        audio_padding_size = sp_world - (audio_seq_len % sp_world)
        audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size)

    x = Slice.apply(group, x, 1, True)

    if self.insert_audio:
        audio = Slice.apply(group, audio, 1, True)

    # arguments shared by every transformer block
    block_kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens,
        audio=audio,
        audio_seq_len=audio_seq_len)

    for layer in self.blocks:
        x = layer(x, **block_kwargs)

    # head
    x = self.head(x, e)

    # ulysses support: gather the sliced sequence and drop the padding.
    x = gather_outputs(x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True)

    # unpatchify
    videos = self.unpatchify(x, grid_sizes)
    return [video.float() for video in videos]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def ulysses_attn_forward(
    self,
    x,
    seq_lens,
    grid_sizes,
    freqs,
    dtype=torch.bfloat16
):
    """Self-attention forward pass with Ulysses sequence parallelism.

    Projects ``x`` to q/k/v, pads the head dimension so it divides the
    parallel world size, exchanges sequence and head shards between ranks,
    applies ``rope_apply`` and ``flash_attention``, exchanges back, removes
    the head padding and applies the output projection ``self.o``.

    Args:
        x: Local sequence shard; the first dimension is used as the batch
            size (assumes ``[B, L_local, C]`` -- TODO confirm against caller).
        seq_lens: Tensor of per-sample sequence lengths; its maximum is used
            as the full (unpadded) sequence length.
        grid_sizes: Grid sizes forwarded to ``rope_apply``.
        freqs: Frequencies forwarded to ``rope_apply``.
        dtype: Dtype q/k/v are cast to before attention unless they are
            already float16 or bfloat16.

    Returns:
        The result of ``self.o`` applied to the attention output.
    """
    b, n, d = x.shape[0], self.num_heads, self.head_dim
    seq_len = seq_lens.max()
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(tensor):
        # Leave half-precision tensors alone; cast everything else to ``dtype``.
        return tensor if tensor.dtype in half_dtypes else tensor.to(dtype)

    # query, key, value
    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(x))
    v = self.v(x)

    # ulysses support: pad whole heads so the head count divides the world size.
    sp_size = get_unified_parallel_world_size()
    inner_dim = n * d
    head_remainder = n % sp_size
    if head_remainder:
        pad_size = (sp_size - head_remainder) * d
        q = pad_tensor(q, dim=2, padding_size=pad_size)
        k = pad_tensor(k, dim=2, padding_size=pad_size)
        v = pad_tensor(v, dim=2, padding_size=pad_size)
    else:
        pad_size = 0
    pad_inner_dim = inner_dim + pad_size

    qkv = torch.cat([q, k, v], dim=2)
    qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len)
    q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2)

    # Heads held by this rank after the exchange.
    pad_split_n = (pad_inner_dim // d) // sp_size
    q = q.view(b, seq_len, pad_split_n, d)
    k = k.view(b, seq_len, pad_split_n, d)
    v = v.view(b, seq_len, pad_split_n, d)

    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    x = flash_attention(
        q=half(q),
        k=half(k),
        v=half(v),
        k_lens=seq_lens,
        window_size=self.window_size
    )

    # ulysses support: merge heads and exchange back to sequence shards.
    x = x.flatten(2)
    x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
    if head_remainder:
        # dim 2 is the (padded) channel dimension here, so it must be cut back
        # to num_heads * head_dim -- the input width of ``self.o`` -- not to
        # the sequence length.
        x = unpad_tensor(x, dim=2, unpad_dim_size=inner_dim)

    return self.o(x)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def ulysses_audio_cross_attn_forward(
    self,
    x,
    audio,
    seq_lens,
    grid_sizes,
    freqs,
    audio_seq_len,
    dtype=torch.bfloat16
):
    """Audio cross-attention forward pass with Ulysses sequence parallelism.

    Queries come from ``x`` and keys/values from ``audio``. After the
    sequence/head exchange, queries are regrouped into chunks of
    ``grid_sizes[0][1] * grid_sizes[0][2]`` tokens and keys/values into chunks
    of 16 tokens, so each query chunk attends to one key/value chunk.

    Args:
        x: Local video-token shard; the first dimension is used as batch size.
        audio: Local audio-token shard.
        seq_lens: Tensor of per-sample video sequence lengths; its maximum is
            used as the full video sequence length.
        grid_sizes: Per-sample grid sizes; only the first sample's last two
            entries are read.
        freqs: Unused here; kept for signature parity with the self-attention
            forward.
        audio_seq_len: Full (unpadded) audio sequence length.
        dtype: Unused here; kept for signature parity.

    Returns:
        The result of ``self.o`` applied to the attention output.

    Raises:
        AssertionError: If the per-frame token count is neither 1560 nor 3600.
    """
    b, n, d = x.shape[0], self.num_heads, self.head_dim
    seq_len = seq_lens.max()

    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(audio))
    v = self.v(audio)

    # ulysses support: pad whole heads so the head count divides the world size.
    sp_size = get_unified_parallel_world_size()
    inner_dim = n * d
    head_remainder = n % sp_size
    if head_remainder:
        pad_size = (sp_size - head_remainder) * d
        q = pad_tensor(q, dim=2, padding_size=pad_size)
        k = pad_tensor(k, dim=2, padding_size=pad_size)
        v = pad_tensor(v, dim=2, padding_size=pad_size)
    else:
        pad_size = 0
    pad_inner_dim = inner_dim + pad_size

    # The exchange helper works on a pair of tensors, so q is duplicated and
    # the second copy discarded afterwards.
    qq = torch.cat([q, q], dim=2)
    kv = torch.cat([k, v], dim=2)
    qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len)
    kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len)
    q, _ = qq.split(pad_inner_dim // sp_size, dim=2)
    k, v = kv.split(pad_inner_dim // sp_size, dim=2)

    # Heads held by this rank after the exchange.
    pad_split_n = (pad_inner_dim // d) // sp_size
    q = q.view(b, seq_len, pad_split_n, d)
    k = k.view(b, audio_seq_len, pad_split_n, d)
    v = v.view(b, audio_seq_len, pad_split_n, d)

    # Tokens per grid slice (product of the last two grid dimensions).
    hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
    assert hlen_wlen == 1560 or hlen_wlen == 3600, (
        f"unsupported per-frame token count: {hlen_wlen}"
    )
    q = q.reshape(-1, hlen_wlen, pad_split_n, d)
    # 16 audio tokens per chunk -- presumably AudioProjModel's context_tokens; confirm.
    k = k.reshape(-1, 16, pad_split_n, d)
    v = v.reshape(-1, 16, pad_split_n, d)

    x = flash_attention(
        q=q,
        k=k,
        v=v,
        k_lens=None,
    )
    x = x.view(b, -1, pad_split_n, d)

    # ulysses support: merge heads and exchange back to sequence shards.
    x = x.flatten(2)
    x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
    if head_remainder:
        # dim 2 is the (padded) channel dimension here, so it must be cut back
        # to num_heads * head_dim -- the input width of ``self.o`` -- not to
        # the sequence length.
        x = unpad_tensor(x, dim=2, unpad_dim_size=inner_dim)

    return self.o(x)
|
humo/models/distributed/fsdp.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 16 |
+
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
|
| 17 |
+
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    """Wrap ``model`` in FSDP, auto-wrapping every member of ``model.blocks``.

    Args:
        model: Module to shard; must expose a ``blocks`` container.
        device_id: Forwarded to FSDP as ``device_id``.
        param_dtype: Parameter dtype of the mixed-precision policy.
        reduce_dtype: Reduction dtype of the mixed-precision policy.
        buffer_dtype: Buffer dtype of the mixed-precision policy.
        process_group: Forwarded to FSDP as ``process_group``.
        sharding_strategy: Forwarded to FSDP as ``sharding_strategy``.
        sync_module_states: Forwarded to FSDP as ``sync_module_states``.

    Returns:
        The FSDP-wrapped model.
    """

    def _is_block(module):
        # A submodule gets its own FSDP unit when it is one of ``model.blocks``.
        return module in model.blocks

    wrap_policy = partial(lambda_auto_wrap_policy, lambda_fn=_is_block)
    precision = MixedPrecision(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        buffer_dtype=buffer_dtype,
    )
    return FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=wrap_policy,
        mixed_precision=precision,
        device_id=device_id,
        sync_module_states=sync_module_states,
    )
|
humo/models/text/encoder.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import List, Optional, Union
|
| 4 |
+
import torch
|
| 5 |
+
from omegaconf import DictConfig, OmegaConf
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForCausalLM,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
CLIPTextModel,
|
| 11 |
+
CLIPTokenizerFast,
|
| 12 |
+
T5EncoderModel,
|
| 13 |
+
T5TokenizerFast,
|
| 14 |
+
)
|
| 15 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 16 |
+
|
| 17 |
+
from common.fs import download_and_extract
|
| 18 |
+
from common.logger import get_logger
|
| 19 |
+
|
| 20 |
+
logger = get_logger(__name__)
|
| 21 |
+
|
| 22 |
+
# Maps the ``type`` string of an encoder config entry to the
# (tokenizer class, model class) pair used to load it.
MODEL_TYPES = dict(
    clip=(CLIPTokenizerFast, CLIPTextModel),
    t5=(T5TokenizerFast, T5EncoderModel),
    llm14b=(AutoTokenizer, AutoModelForCausalLM),
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class TextEncoderOutput:
    """Container for the three results returned by ``TextEncoder.forward``."""

    # Token embeddings: one tensor, or a list of tensors.
    embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]]
    # Boolean masks matching ``embeddings``.
    masks: Union[torch.BoolTensor, List[torch.BoolTensor]]
    # Pooled output; ``None`` when no pooled output is configured.
    pooled: Optional[Union[torch.FloatTensor, List[torch.FloatTensor]]]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TextEncoder(nn.Module):
    """Encode text with every tokenizer/model pair listed in ``config.models``.

    Each entry of ``config.models`` provides ``type`` (a key of
    ``MODEL_TYPES``), ``path`` and ``max_length``, and may provide ``mask``
    and ``layer``. The optional ``config.output`` section selects how the
    per-model embeddings, masks and pooled outputs are merged.
    """

    def __init__(self, config: DictConfig):
        super().__init__()
        self.config = config
        self.tokenizers = []
        self.models = nn.ModuleList([])

        # Disable tokenizer parallelism since we already use distributed training.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        for spec in config.models:
            tokenizer_cls, model_cls = MODEL_TYPES[spec.type]
            path = download_and_extract(spec.path)
            max_length = spec.max_length

            if spec.type == "llm14b":
                tokenizer = tokenizer_cls.from_pretrained(
                    path,
                    model_max_length=max_length,
                    use_fast=False,
                    trust_remote_code=True,
                    padding_side="right",
                    truncation_side="right",
                    add_eod_token=True,
                )
                tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})
                loaded = model_cls.from_pretrained(path, trust_remote_code=True, bf16=True)
            else:
                tokenizer = tokenizer_cls.from_pretrained(path, model_max_length=max_length)
                loaded = model_cls.from_pretrained(path, torch_dtype=torch.bfloat16)

            self.tokenizers.append(tokenizer)
            self.models.append(loaded)

    def _encode_llm(self, tokenizer, model, text, use_mask):
        """Encode ``text`` with an ``llm14b`` model; return ``(embedding, mask)``."""
        tokens = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        ).to(model.device)
        token_ids = tokens["input_ids"]
        attention_mask = tokens["attention_mask"]

        # Write a pad token right after the last real token of every row
        # (clamped to the final position) and mark that position as attended.
        num_tokens = attention_mask.sum(dim=1)
        row_ids = torch.arange(len(token_ids), device=token_ids.device, dtype=torch.long)
        end_positions = num_tokens.clamp(max=token_ids.size(1) - 1)
        token_ids[row_ids, end_positions] = tokenizer.pad_token_id
        attention_mask[row_ids, end_positions] = 1

        output = model.transformer(
            input_ids=token_ids,
            attention_mask=attention_mask if use_mask else None,
            output_hidden_states=False,
            use_cache=False,
        )
        emb = output.last_hidden_state  # batch_size, num_tokens, feat_dim

        # Without masking every position counts as valid.
        mask = attention_mask.bool() if use_mask else attention_mask > -1
        return emb, mask

    def _encode_standard(self, encoder_config, tokenizer, model, text, use_mask):
        """Encode ``text`` with a non-LLM model; return ``(embedding, mask, output)``."""
        # Tokenizer
        tokens = tokenizer(
            text=text,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        # Encoder
        input_ids = tokens.input_ids.to(model.device)
        attention_mask = tokens.attention_mask.to(model.device)
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask if use_mask else None,
            output_hidden_states=True,
        )

        # Pick the embedding from the configured layer.
        layer = encoder_config.get("layer", "last")
        if layer == "last":
            emb = output.last_hidden_state
        elif layer == "penultimate":
            emb = model.text_model.final_layer_norm(output.hidden_states[-2])
        elif layer == "penultimate_nonorm":
            emb = output.hidden_states[-2]
        else:
            raise NotImplementedError(f"Unknown layer type: {layer}.")

        # Without masking every position counts as valid.
        mask = attention_mask.bool() if use_mask else attention_mask > -1
        return emb, mask, output

    def forward(self, text: Union[str, List[str]]) -> TextEncoderOutput:
        """Encode ``text`` with all models and merge the results.

        Args:
            text: A string or list of strings passed to every tokenizer.

        Returns:
            A ``TextEncoderOutput`` holding the selected or merged embeddings,
            masks and pooled output.

        Raises:
            NotImplementedError: For an unknown ``layer``,
                ``output.embedding_and_mask`` or ``output.pooled`` setting.
        """
        embeddings, masks, pooled = [], [], []

        for encoder_config, tokenizer, model in zip(
            self.config.models, self.tokenizers, self.models
        ):
            use_mask = encoder_config.get("mask", True)
            if encoder_config.type == "llm14b":
                emb, mask = self._encode_llm(tokenizer, model, text, use_mask)
                embeddings.append(emb)
                masks.append(mask)
                continue

            emb, mask, output = self._encode_standard(
                encoder_config, tokenizer, model, text, use_mask
            )
            embeddings.append(emb)
            masks.append(mask)

            # Save pooled output if available.
            if hasattr(output, "pooler_output"):
                pooled.append(output.pooler_output)

        output_config = self.config.get("output") or OmegaConf.create()
        embedding_output_type = output_config.get("embedding_and_mask", "undefined")
        pooled_output_type = output_config.get("pooled", "undefined")

        # Select or merge embeddings and mask if needed.
        if embedding_output_type == "undefined" and len(self.models) == 1:
            embeddings, masks = embeddings[0], masks[0]
        elif embedding_output_type == "channel_concat":
            embeddings = torch.cat(embeddings, dim=-1)
            masks = sum(masks).bool()
        elif embedding_output_type == "last":
            embeddings, masks = embeddings[-1], masks[-1]
        else:
            raise NotImplementedError(f"output.embedding_and_mask: {embedding_output_type}")

        # Select or merge pooled output if needed.
        if pooled_output_type == "undefined":
            pooled = None
        elif pooled_output_type == "channel_concat":
            pooled = torch.cat(pooled, dim=-1)
        elif pooled_output_type == "first":
            pooled = pooled[0]
        elif pooled_output_type == "last":
            pooled = pooled[-1]
        else:
            raise NotImplementedError(f"output.pooled: {pooled_output_type}")

        # Return final results.
        return TextEncoderOutput(embeddings, masks, pooled)
|
humo/models/utils/fm_solvers.py
ADDED
|
@@ -0,0 +1,857 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
|
| 2 |
+
# Convert dpm solver for flow matching
|
| 3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import inspect
|
| 6 |
+
import math
|
| 7 |
+
from typing import List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
|
| 13 |
+
SchedulerMixin,
|
| 14 |
+
SchedulerOutput)
|
| 15 |
+
from diffusers.utils import deprecate, is_scipy_available
|
| 16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 17 |
+
|
| 18 |
+
if is_scipy_available():
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_sampling_sigmas(sampling_steps, shift):
    """Return `sampling_steps` shifted sigma values for sampling.

    Args:
        sampling_steps: Number of sigma values to produce.
        shift: Shift factor; every linear sigma `s` is mapped to
            `shift * s / (1 + (shift - 1) * s)`.

    Returns:
        A NumPy array of length `sampling_steps`. The unshifted grid starts
        at 1 and the trailing 0 of the grid is dropped.
    """
    # sampling_steps + 1 evenly spaced points on [1, 0]; keep the first sampling_steps.
    linear = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps]
    denominator = 1 + (shift - 1) * linear
    return shift * linear / denominator
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps=None,
    device=None,
    timesteps=None,
    sigmas=None,
    **kwargs,
):
    """Call `scheduler.set_timesteps` and return the resulting schedule.

    Args:
        scheduler: Object exposing `set_timesteps(...)` and a `timesteps`
            attribute that is populated by that call.
        num_inference_steps: Step count, used only when neither `timesteps`
            nor `sigmas` is given.
        device: Forwarded to `set_timesteps` as `device=`.
        timesteps: Custom timestep schedule (mutually exclusive with `sigmas`).
        sigmas: Custom sigma schedule (mutually exclusive with `timesteps`).
        **kwargs: Forwarded unchanged to `set_timesteps`.

    Returns:
        `(timesteps, num_inference_steps)`. With a custom schedule the count
        is `len(scheduler.timesteps)`; otherwise the argument is echoed back.

    Raises:
        ValueError: If both custom schedules are given, or if the scheduler's
            `set_timesteps` signature lacks the requested keyword.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    # Default path: let the scheduler build its own schedule.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 70 |
+
"""
|
| 71 |
+
`FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
|
| 72 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 73 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 74 |
+
Args:
|
| 75 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 76 |
+
The number of diffusion steps to train the model. This determines the resolution of the diffusion process.
|
| 77 |
+
solver_order (`int`, defaults to 2):
|
| 78 |
+
The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided
|
| 79 |
+
sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored
|
| 80 |
+
and used in multistep updates.
|
| 81 |
+
prediction_type (`str`, defaults to "flow_prediction"):
|
| 82 |
+
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
| 83 |
+
the flow of the diffusion process.
|
| 84 |
+
shift (`float`, *optional*, defaults to 1.0):
|
| 85 |
+
A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling
|
| 86 |
+
process.
|
| 87 |
+
use_dynamic_shifting (`bool`, defaults to `False`):
|
| 88 |
+
Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is
|
| 89 |
+
applied on the fly.
|
| 90 |
+
thresholding (`bool`, defaults to `False`):
|
| 91 |
+
Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent
|
| 92 |
+
saturation and improve photorealism.
|
| 93 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 94 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 95 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 96 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
| 97 |
+
`algorithm_type="dpmsolver++"`.
|
| 98 |
+
algorithm_type (`str`, defaults to `dpmsolver++`):
|
| 99 |
+
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
| 100 |
+
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
| 101 |
+
paper, and the `dpmsolver++` type implements the algorithms in the
|
| 102 |
+
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
| 103 |
+
`sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
|
| 104 |
+
solver_type (`str`, defaults to `midpoint`):
|
| 105 |
+
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
|
| 106 |
+
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
|
| 107 |
+
lower_order_final (`bool`, defaults to `True`):
|
| 108 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 109 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 110 |
+
euler_at_final (`bool`, defaults to `False`):
|
| 111 |
+
Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
|
| 112 |
+
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
|
| 113 |
+
steps, but sometimes may result in blurring.
|
| 114 |
+
final_sigmas_type (`str`, *optional*, defaults to "zero"):
|
| 115 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 116 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 117 |
+
lambda_min_clipped (`float`, defaults to `-inf`):
|
| 118 |
+
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
| 119 |
+
cosine (`squaredcos_cap_v2`) noise schedule.
|
| 120 |
+
variance_type (`str`, *optional*):
|
| 121 |
+
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
| 122 |
+
contains the predicted Gaussian variance.
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 126 |
+
order = 1
|
| 127 |
+
|
| 128 |
+
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    solver_order: int = 2,
    prediction_type: str = "flow_prediction",
    shift: Optional[float] = 1.0,
    use_dynamic_shifting=False,
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    algorithm_type: str = "dpmsolver++",
    solver_type: str = "midpoint",
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    lambda_min_clipped: float = -float("inf"),
    variance_type: Optional[str] = None,
    invert_sigmas: bool = False,
):
    """Validate the solver options and build the training sigma table.

    The parameters are described in the class docstring. `invert_sigmas`
    is accepted but not read anywhere in this constructor.

    Raises:
        NotImplementedError: For an unknown `algorithm_type` (other than
            `"deis"`) or an unknown `solver_type` (other than `"logrho"`,
            `"bh1"`, `"bh2"`).
        ValueError: If `final_sigmas_type == "zero"` while `algorithm_type`
            is not `dpmsolver++` / `sde-dpmsolver++`.
    """
    if algorithm_type in ("dpmsolver", "sde-dpmsolver"):
        deprecate(
            "algorithm_types dpmsolver and sde-dpmsolver",
            "1.0.0",
            f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead",
        )

    # settings for DPM-Solver
    known_algorithms = ("dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++")
    if algorithm_type not in known_algorithms:
        if algorithm_type != "deis":
            raise NotImplementedError(
                f"{algorithm_type} is not implemented for {self.__class__}")
        # "deis" is re-registered in the config as "dpmsolver++".
        self.register_to_config(algorithm_type="dpmsolver++")

    if solver_type not in ("midpoint", "heun"):
        if solver_type not in ("logrho", "bh1", "bh2"):
            raise NotImplementedError(
                f"{solver_type} is not implemented for {self.__class__}")
        # These solver names are re-registered in the config as "midpoint".
        self.register_to_config(solver_type="midpoint")

    # NOTE(review): this check uses the local `algorithm_type`, which is still
    # "deis" after the fallback above, so "deis" with the default
    # final_sigmas_type="zero" raises here -- confirm whether that is intended.
    zero_final_supported = algorithm_type in ("dpmsolver++", "sde-dpmsolver++")
    if final_sigmas_type == "zero" and not zero_final_supported:
        raise ValueError(
            f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
        )

    # setable values
    self.num_inference_steps = None

    # alphas rise from 1/N to 1, so sigmas = 1 - alphas fall from 1 - 1/N to 0.
    alphas = np.linspace(1, 1 / num_train_timesteps,
                         num_train_timesteps)[::-1].copy()
    sigmas = torch.from_numpy(1.0 - alphas).to(dtype=torch.float32)

    if not use_dynamic_shifting:
        # when use_dynamic_shifting is True, the shift is applied on the fly in set_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

    self.sigmas = sigmas
    self.timesteps = sigmas * num_train_timesteps

    # Multistep state: one history slot per solver order.
    self.model_outputs = [None] * solver_order
    self.lower_order_nums = 0
    self._step_index = None
    self._begin_index = None

    self.sigma_min = self.sigmas[-1].item()
    self.sigma_max = self.sigmas[0].item()
|
| 200 |
+
|
| 201 |
+
@property
|
| 202 |
+
def step_index(self):
|
| 203 |
+
"""
|
| 204 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
| 205 |
+
"""
|
| 206 |
+
return self._step_index
|
| 207 |
+
|
| 208 |
+
@property
|
| 209 |
+
def begin_index(self):
|
| 210 |
+
"""
|
| 211 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
| 212 |
+
"""
|
| 213 |
+
return self._begin_index
|
| 214 |
+
|
| 215 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 216 |
+
def set_begin_index(self, begin_index: int = 0):
    """Store the index at which the scheduler should start.

    Args:
        begin_index (`int`):
            Value written to `_begin_index`; defaults to 0.
    """
    setattr(self, "_begin_index", begin_index)
|
| 224 |
+
|
| 225 |
+
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
| 226 |
+
def set_timesteps(
    self,
    num_inference_steps: Union[int, None] = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[Union[float, None]] = None,
    shift: Optional[Union[float, None]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            Number of sampling steps. Required when `sigmas` is not given.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        sigmas (`List[float]` or array, *optional*):
            Custom sigma schedule (without the final sigma). When omitted, a linear grid from
            `self.sigma_max` to `self.sigma_min` is used with its last point dropped.
        mu (`float`, *optional*):
            Parameter passed to `time_shift`; required when `config.use_dynamic_shifting` is true.
        shift (`float`, *optional*):
            Static shift factor; falls back to `config.shift` when `None`. Ignored with dynamic shifting.

    Raises:
        ValueError: If `mu` is missing under dynamic shifting, if neither `num_inference_steps` nor
            `sigmas` is given, or if `config.final_sigmas_type` is not `'zero'` / `'sigma_min'`.
    """
    if self.config.use_dynamic_shifting and mu is None:
        raise ValueError(
            " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
        )

    if sigmas is None:
        if num_inference_steps is None:
            raise ValueError(
                "Either `num_inference_steps` or `sigmas` must be provided.")
        sigmas = np.linspace(self.sigma_max, self.sigma_min,
                             num_inference_steps + 1).copy()[:-1]
    else:
        # Plain Python lists do not support the element-wise arithmetic below.
        sigmas = np.asarray(sigmas)

    if self.config.use_dynamic_shifting:
        sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
    else:
        if shift is None:
            shift = self.config.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

    if self.config.final_sigmas_type == "sigma_min":
        # This class defines no `alphas_cumprod`; the last training sigma is
        # kept on the instance as `sigma_min`.
        sigma_last = self.sigma_min
    elif self.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Timesteps are derived before the final sigma is appended, so there is
    # one more sigma than there are timesteps.
    timesteps = sigmas * self.config.num_train_timesteps
    sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)  # pyright: ignore

    self.sigmas = torch.from_numpy(sigmas)
    self.timesteps = torch.from_numpy(timesteps).to(
        device=device, dtype=torch.int64)

    self.num_inference_steps = len(timesteps)

    # Reset the multistep history and the step counters.
    self.model_outputs = [None] * self.config.solver_order
    self.lower_order_nums = 0

    self._step_index = None
    self._begin_index = None
|
| 288 |
+
# self.sigmas = self.sigmas.to(
|
| 289 |
+
# "cpu") # to avoid too much CPU/GPU communication
|
| 290 |
+
|
| 291 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 292 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
| 293 |
+
"""
|
| 294 |
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
| 295 |
+
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
| 296 |
+
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
| 297 |
+
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
| 298 |
+
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
| 299 |
+
https://arxiv.org/abs/2205.11487
|
| 300 |
+
"""
|
| 301 |
+
dtype = sample.dtype
|
| 302 |
+
batch_size, channels, *remaining_dims = sample.shape
|
| 303 |
+
|
| 304 |
+
if dtype not in (torch.float32, torch.float64):
|
| 305 |
+
sample = sample.float(
|
| 306 |
+
) # upcast for quantile calculation, and clamp not implemented for cpu half
|
| 307 |
+
|
| 308 |
+
# Flatten sample for doing quantile calculation along each image
|
| 309 |
+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
| 310 |
+
|
| 311 |
+
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
| 312 |
+
|
| 313 |
+
s = torch.quantile(
|
| 314 |
+
abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
| 315 |
+
s = torch.clamp(
|
| 316 |
+
s, min=1, max=self.config.sample_max_value
|
| 317 |
+
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
| 318 |
+
s = s.unsqueeze(
|
| 319 |
+
1) # (batch_size, 1) because clamp will broadcast along dim=0
|
| 320 |
+
sample = torch.clamp(
|
| 321 |
+
sample, -s, s
|
| 322 |
+
) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
| 323 |
+
|
| 324 |
+
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
| 325 |
+
sample = sample.to(dtype)
|
| 326 |
+
|
| 327 |
+
return sample
|
| 328 |
+
|
| 329 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
| 330 |
+
def _sigma_to_t(self, sigma):
|
| 331 |
+
return sigma * self.config.num_train_timesteps
|
| 332 |
+
|
| 333 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
|
| 334 |
+
return 1 - sigma, sigma
|
| 335 |
+
|
| 336 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps
|
| 337 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """Return `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)` for the given `t`."""
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1)**sigma)
|
| 339 |
+
|
| 340 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
|
| 341 |
+
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Convert the model output to the quantity the configured solver integrates: a data prediction (`x0`) for
    `dpmsolver++` / `sde-dpmsolver++`, or a noise prediction (`epsilon`) for `dpmsolver` / `sde-dpmsolver`.

    Only `prediction_type == "flow_prediction"` is handled. With `sigma_t = self.sigmas[self.step_index]`:

        x0      = sample - sigma_t * model_output
        epsilon = sample + (1 - sigma_t) * model_output

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be given as the
            second positional argument after the deprecated `timestep`.

    Returns:
        `torch.Tensor`:
            The converted model output.

    Raises:
        ValueError: If `sample` is missing or `config.prediction_type` is not `flow_prediction`.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError(
                "missing `sample` as a required keyward argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # DPM-Solver++ needs to solve an integral of the data prediction model.
    if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]:
        if self.config.prediction_type == "flow_prediction":
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)

        return x0_pred

    # DPM-Solver needs to solve an integral of the noise prediction model.
    elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
        if self.config.prediction_type == "flow_prediction":
            sigma_t = self.sigmas[self.step_index]
            # With x0 = sample - sigma_t * v and alpha_t = 1 - sigma_t, the
            # noise is x0 + v = sample + (1 - sigma_t) * v. This matches the
            # thresholding path below when the clamp is a no-op.
            epsilon = sample + (1 - sigma_t) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
                " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler."
            )

        if self.config.thresholding:
            # Threshold in x0 space, then map back to a noise prediction.
            sigma_t = self.sigmas[self.step_index]
            x0_pred = sample - sigma_t * model_output
            x0_pred = self._threshold_sample(x0_pred)
            epsilon = model_output + x0_pred

        return epsilon
|
| 413 |
+
|
| 414 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update
|
| 415 |
+
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Take one first-order DPMSolver step from `sigmas[step_index]` to `sigmas[step_index + 1]`.

    Args:
        model_output (`torch.Tensor`):
            The converted model output (see `convert_model_output`).
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be the third
            positional argument after the deprecated `timestep` and `prev_timestep`.
        noise (`torch.Tensor`, *optional*):
            Noise tensor; must not be `None` for the `sde-*` algorithm types.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    n_positional = len(args)
    timestep = args[0] if n_positional > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if n_positional > 1 else kwargs.pop(
        "prev_timestep", None)
    if sample is None:
        if n_positional <= 2:
            raise ValueError(
                " missing `sample` as a required keyward argument")
        sample = args[2]
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # "t" is the target of the step, "s" the current position.
    target_sigma = self.sigmas[self.step_index + 1]  # pyright: ignore
    current_sigma = self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(target_sigma)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(current_sigma)

    # h is the step size in lambda = log(alpha) - log(sigma).
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
    h = lambda_t - lambda_s

    algorithm = self.config.algorithm_type
    if algorithm == "dpmsolver++":
        data_coeff = alpha_t * (torch.exp(-h) - 1.0)
        x_t = (sigma_t / sigma_s) * sample - data_coeff * model_output
    elif algorithm == "dpmsolver":
        noise_coeff = sigma_t * (torch.exp(h) - 1.0)
        x_t = (alpha_t / alpha_s) * sample - noise_coeff * model_output
    elif algorithm == "sde-dpmsolver++":
        assert noise is not None
        carry = sigma_t / sigma_s * torch.exp(-h)
        data_coeff = alpha_t * (1 - torch.exp(-2.0 * h))
        x_t = (carry * sample + data_coeff * model_output +
               sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
    elif algorithm == "sde-dpmsolver":
        assert noise is not None
        noise_coeff = sigma_t * (torch.exp(h) - 1.0)
        x_t = ((alpha_t / alpha_s) * sample - 2.0 * noise_coeff * model_output +
               sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
    return x_t  # pyright: ignore
|
| 484 |
+
|
| 485 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update
|
| 486 |
+
def multistep_dpm_solver_second_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """
    Take one second-order multistep DPMSolver step using the two most recent model outputs.

    Args:
        model_output_list (`List[torch.Tensor]`):
            Converted model outputs; the last entry belongs to the current step and the one before it
            to the preceding step.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be the third
            positional argument after the deprecated `timestep_list` and `prev_timestep`.
        noise (`torch.Tensor`, *optional*):
            Noise tensor; must not be `None` for the `sde-*` algorithm types.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    n_positional = len(args)
    timestep_list = args[0] if n_positional > 0 else kwargs.pop(
        "timestep_list", None)
    prev_timestep = args[1] if n_positional > 1 else kwargs.pop(
        "prev_timestep", None)
    if sample is None:
        if n_positional <= 2:
            raise ValueError(
                " missing `sample` as a required keyward argument")
        sample = args[2]
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # "t" is the target, "s0" the current position, "s1" the one before it.
    raw_t = self.sigmas[self.step_index + 1]  # pyright: ignore
    raw_s0 = self.sigmas[self.step_index]
    raw_s1 = self.sigmas[self.step_index - 1]  # pyright: ignore

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(raw_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(raw_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(raw_s1)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)

    m0 = model_output_list[-1]
    m1 = model_output_list[-2]

    h = lambda_t - lambda_s0
    h_0 = lambda_s0 - lambda_s1
    r0 = h_0 / h
    # D0 is the current output; D1 is the scaled backward difference.
    D0 = m0
    D1 = (1.0 / r0) * (m0 - m1)

    algorithm = self.config.algorithm_type
    solver = self.config.solver_type
    if algorithm == "dpmsolver++":
        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        coeff = alpha_t * (torch.exp(-h) - 1.0)
        if solver == "midpoint":
            x_t = ((sigma_t / sigma_s0) * sample - coeff * D0 -
                   0.5 * coeff * D1)
        elif solver == "heun":
            x_t = ((sigma_t / sigma_s0) * sample - coeff * D0 +
                   (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1)
    elif algorithm == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        coeff = sigma_t * (torch.exp(h) - 1.0)
        if solver == "midpoint":
            x_t = ((alpha_t / alpha_s0) * sample - coeff * D0 -
                   0.5 * coeff * D1)
        elif solver == "heun":
            x_t = ((alpha_t / alpha_s0) * sample - coeff * D0 -
                   (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1)
    elif algorithm == "sde-dpmsolver++":
        assert noise is not None
        carry = sigma_t / sigma_s0 * torch.exp(-h)
        coeff = alpha_t * (1 - torch.exp(-2.0 * h))
        if solver == "midpoint":
            x_t = (carry * sample + coeff * D0 + 0.5 * coeff * D1 +
                   sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
        elif solver == "heun":
            x_t = (carry * sample + coeff * D0 +
                   (alpha_t * ((1.0 - torch.exp(-2.0 * h)) /
                               (-2.0 * h) + 1.0)) * D1 +
                   sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise)
    elif algorithm == "sde-dpmsolver":
        assert noise is not None
        coeff = sigma_t * (torch.exp(h) - 1.0)
        if solver == "midpoint":
            x_t = ((alpha_t / alpha_s0) * sample - 2.0 * coeff * D0 -
                   coeff * D1 +
                   sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
        elif solver == "heun":
            x_t = ((alpha_t / alpha_s0) * sample - 2.0 * coeff * D0 - 2.0 *
                   (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 +
                   sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise)
    return x_t  # pyright: ignore
|
| 594 |
+
|
| 595 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update
|
| 596 |
+
def multistep_dpm_solver_third_order_update(
    self,
    model_output_list: List[torch.Tensor],
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    """
    Take one third-order multistep DPMSolver step using the three most recent model outputs.

    Only the `dpmsolver++` and `dpmsolver` algorithm types are handled here.

    Args:
        model_output_list (`List[torch.Tensor]`):
            Converted model outputs; the last three entries are used, newest last.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be the third
            positional argument after the deprecated `timestep_list` and `prev_timestep`.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    n_positional = len(args)
    timestep_list = args[0] if n_positional > 0 else kwargs.pop(
        "timestep_list", None)
    prev_timestep = args[1] if n_positional > 1 else kwargs.pop(
        "prev_timestep", None)
    if sample is None:
        if n_positional <= 2:
            raise ValueError(
                " missing`sample` as a required keyward argument")
        sample = args[2]
    if timestep_list is not None:
        deprecate(
            "timestep_list",
            "1.0.0",
            "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    # "t" is the target; "s0", "s1", "s2" are the current and two earlier positions.
    raw_t = self.sigmas[self.step_index + 1]  # pyright: ignore
    raw_s0 = self.sigmas[self.step_index]
    raw_s1 = self.sigmas[self.step_index - 1]  # pyright: ignore
    raw_s2 = self.sigmas[self.step_index - 2]  # pyright: ignore

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(raw_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(raw_s0)
    alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(raw_s1)
    alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(raw_s2)

    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
    lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
    lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2)

    m0 = model_output_list[-1]
    m1 = model_output_list[-2]
    m2 = model_output_list[-3]

    h = lambda_t - lambda_s0
    h_0 = lambda_s0 - lambda_s1
    h_1 = lambda_s1 - lambda_s2
    r0 = h_0 / h
    r1 = h_1 / h

    # Scaled differences of the stored outputs, combined into D1 and D2.
    D0 = m0
    D1_0 = (1.0 / r0) * (m0 - m1)
    D1_1 = (1.0 / r1) * (m1 - m2)
    D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
    D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)

    algorithm = self.config.algorithm_type
    if algorithm == "dpmsolver++":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        exp_neg_h = torch.exp(-h)
        x_t = ((sigma_t / sigma_s0) * sample -
               (alpha_t * (exp_neg_h - 1.0)) * D0 +
               (alpha_t * ((exp_neg_h - 1.0) / h + 1.0)) * D1 -
               (alpha_t * ((exp_neg_h - 1.0 + h) / h**2 - 0.5)) * D2)
    elif algorithm == "dpmsolver":
        # See https://arxiv.org/abs/2206.00927 for detailed derivations
        exp_h = torch.exp(h)
        x_t = ((alpha_t / alpha_s0) * sample -
               (sigma_t * (exp_h - 1.0)) * D0 -
               (sigma_t * ((exp_h - 1.0) / h - 1.0)) * D1 -
               (sigma_t * ((exp_h - 1.0 - h) / h**2 - 0.5)) * D2)
    return x_t  # pyright: ignore
|
| 678 |
+
|
| 679 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep: The timestep value to look up.
        schedule_timesteps: Schedule to search; `self.timesteps` is used when omitted.

    Returns:
        `int`: Index of the match. When the value occurs more than once, the second
        occurrence is returned; otherwise the only occurrence.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    # On the very first `step` we deliberately take the second match when the
    # timestep is duplicated, so that starting in the middle of the denoising
    # schedule (e.g. image-to-image) does not skip a sigma.
    chosen = matches[1] if len(matches) > 1 else matches[0]
    return chosen.item()
|
| 692 |
+
|
| 693 |
+
def _init_step_index(self, timestep):
    """
    Set `self._step_index` to the position the scheduler should start from.

    If a begin index was configured it is used verbatim; otherwise the index is
    looked up from `timestep` via `index_for_timestep`.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    # Tensor timesteps must live on the same device as the schedule before comparing.
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
|
| 704 |
+
|
| 705 |
+
# Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
|
| 706 |
+
def step(
    self,
    model_output: torch.Tensor,
    timestep: Union[int, torch.Tensor],
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    """
    Advance `sample` by one multistep DPM-Solver update.

    Args:
        model_output (`torch.Tensor`):
            Raw output of the learned model for the current step.
        timestep (`int` or `torch.Tensor`):
            Current timestep; only used to initialise the internal step index on the first call.
        sample (`torch.Tensor`):
            Current sample produced by the sampling process.
        generator (`torch.Generator`, *optional*):
            Random generator used when noise has to be drawn (SDE algorithm types only).
        variance_noise (`torch.Tensor`, *optional*):
            Pre-drawn noise used instead of sampling from `generator` (SDE algorithm types only).
        return_dict (`bool`):
            When `True` a `SchedulerOutput` is returned, otherwise a one-element tuple.

    Returns:
        `SchedulerOutput` or `tuple`: The updated sample, as `prev_sample` or as the first tuple element.

    Raises:
        ValueError: If `set_timesteps` has not been called yet.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    cfg = self.config
    total_steps = len(self.timesteps)

    # Falling back to lower-order updates near the end improves numerical
    # stability when only a few steps are used.
    short_schedule = cfg.lower_order_final and total_steps < 15
    use_first_order_at_end = (self.step_index == total_steps - 1) and (
        cfg.euler_at_final or short_schedule or cfg.final_sigmas_type == "zero"
    )
    use_second_order_near_end = (self.step_index == total_steps - 2) and short_schedule

    # Slide the history window left by one and append the newest converted output.
    model_output = self.convert_model_output(model_output, sample=sample)
    for slot in range(cfg.solver_order - 1):
        self.model_outputs[slot] = self.model_outputs[slot + 1]
    self.model_outputs[-1] = model_output

    # Upcast to avoid precision issues when computing prev_sample.
    sample = sample.to(torch.float32)

    is_stochastic = cfg.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]
    if not is_stochastic:
        noise = None
    elif variance_noise is None:
        noise = randn_tensor(
            model_output.shape,
            generator=generator,
            device=model_output.device,
            dtype=torch.float32,
        )
    else:
        noise = variance_noise.to(device=model_output.device, dtype=torch.float32)  # pyright: ignore

    # Pick the solver order: limited by configuration, by how much history has
    # been accumulated so far, and by the end-of-schedule fallbacks above.
    if cfg.solver_order == 1 or self.lower_order_nums < 1 or use_first_order_at_end:
        prev_sample = self.dpm_solver_first_order_update(
            model_output, sample=sample, noise=noise
        )
    elif cfg.solver_order == 2 or self.lower_order_nums < 2 or use_second_order_near_end:
        prev_sample = self.multistep_dpm_solver_second_order_update(
            self.model_outputs, sample=sample, noise=noise
        )
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(
            self.model_outputs, sample=sample
        )

    if self.lower_order_nums < cfg.solver_order:
        self.lower_order_nums += 1

    # Cast the result back to the dtype of the (converted) model output.
    prev_sample = prev_sample.to(model_output.dtype)

    # The step is complete: move the internal counter forward.
    self._step_index += 1  # pyright: ignore

    if return_dict:
        return SchedulerOutput(prev_sample=prev_sample)
    return (prev_sample,)
|
| 798 |
+
|
| 799 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input
|
| 800 |
+
def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Return `sample` unchanged.

    This scheduler applies no input scaling; the method exists so it can be
    swapped with schedulers that rescale the model input per timestep. Extra
    positional and keyword arguments are accepted and ignored.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`: The very same `sample` object.
    """
    return sample
|
| 813 |
+
|
| 814 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
| 815 |
+
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` at the noise level of the given timesteps.

    The sigma for each batch element is chosen as follows:
      * no begin index set: looked up per timestep with `index_for_timestep`;
      * begin index set and a step already taken: the current `step_index` for all elements;
      * begin index set and no step taken yet: `begin_index` for all elements.

    Returns:
        `torch.Tensor`: `alpha_t * original_samples + sigma_t * noise`, where
        `(alpha_t, sigma_t)` come from `_sigma_to_alpha_sigma_t`.
    """
    target_device = original_samples.device

    # Sigmas follow the device and dtype of the samples they will scale.
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target_device, dtype=torch.float32)
        timesteps = timesteps.to(target_device, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target_device)
        timesteps = timesteps.to(target_device)

    batch = timesteps.shape[0]
    if self.begin_index is None:
        # Training use, or the pipeline does not implement set_begin_index.
        step_indices = [
            self.index_for_timestep(t, schedule_timesteps) for t in timesteps
        ]
    elif self.step_index is not None:
        # Called after the first denoising step (inpainting).
        step_indices = [self.step_index] * batch
    else:
        # Called before the first denoising step to build the initial latent (img2img).
        step_indices = [self.begin_index] * batch

    # Append singleton axes so sigma broadcasts against the sample tensor.
    sigma = sigmas[step_indices].flatten()
    missing_axes = len(original_samples.shape) - len(sigma.shape)
    sigma = sigma.reshape(tuple(sigma.shape) + (1,) * missing_axes)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    return alpha_t * original_samples + sigma_t * noise
|
| 855 |
+
|
| 856 |
+
def __len__(self):
    """Return the configured number of training timesteps."""
    num_train_timesteps = self.config.num_train_timesteps
    return num_train_timesteps
|
humo/models/utils/fm_solvers_unipc.py
ADDED
|
@@ -0,0 +1,800 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
|
| 2 |
+
# Convert unipc for flow matching
|
| 3 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import List, Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 11 |
+
from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers,
|
| 12 |
+
SchedulerMixin,
|
| 13 |
+
SchedulerOutput)
|
| 14 |
+
from diffusers.utils import deprecate, is_scipy_available
|
| 15 |
+
|
| 16 |
+
if is_scipy_available():
|
| 17 |
+
import scipy.stats
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
| 21 |
+
"""
|
| 22 |
+
`FlowUniPCMultistepScheduler` is the flow-matching variant of `UniPCMultistepScheduler`, a training-free framework designed for the fast sampling of diffusion models.
|
| 23 |
+
|
| 24 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
| 25 |
+
methods the library implements for all schedulers such as loading and saving.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
num_train_timesteps (`int`, defaults to 1000):
|
| 29 |
+
The number of diffusion steps to train the model.
|
| 30 |
+
solver_order (`int`, default `2`):
|
| 31 |
+
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
|
| 32 |
+
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
|
| 33 |
+
unconditional sampling.
|
| 34 |
+
prediction_type (`str`, defaults to "flow_prediction"):
|
| 35 |
+
Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts
|
| 36 |
+
the flow of the diffusion process.
|
| 37 |
+
thresholding (`bool`, defaults to `False`):
|
| 38 |
+
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
| 39 |
+
as Stable Diffusion.
|
| 40 |
+
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
| 41 |
+
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
| 42 |
+
sample_max_value (`float`, defaults to 1.0):
|
| 43 |
+
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
|
| 44 |
+
predict_x0 (`bool`, defaults to `True`):
|
| 45 |
+
Whether to use the updating algorithm on the predicted x0.
|
| 46 |
+
solver_type (`str`, default `bh2`):
|
| 47 |
+
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
|
| 48 |
+
otherwise.
|
| 49 |
+
lower_order_final (`bool`, default `True`):
|
| 50 |
+
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
|
| 51 |
+
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
| 52 |
+
disable_corrector (`list`, default `[]`):
|
| 53 |
+
Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)`
|
| 54 |
+
and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is
|
| 55 |
+
usually disabled during the first few steps.
|
| 56 |
+
solver_p (`SchedulerMixin`, default `None`):
|
| 57 |
+
Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`.
|
| 58 |
+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
| 59 |
+
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
| 60 |
+
the sigmas are determined according to a sequence of noise levels {σi}.
|
| 61 |
+
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
| 62 |
+
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
| 63 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
| 64 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
| 65 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
| 66 |
+
steps_offset (`int`, defaults to 0):
|
| 67 |
+
An offset added to the inference steps, as required by some model families.
|
| 68 |
+
final_sigmas_type (`str`, defaults to `"zero"`):
|
| 69 |
+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
| 70 |
+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
| 74 |
+
order = 1
|
| 75 |
+
|
| 76 |
+
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    solver_order: int = 2,
    prediction_type: str = "flow_prediction",
    shift: Optional[float] = 1.0,
    use_dynamic_shifting=False,
    thresholding: bool = False,
    dynamic_thresholding_ratio: float = 0.995,
    sample_max_value: float = 1.0,
    predict_x0: bool = True,
    solver_type: str = "bh2",
    lower_order_final: bool = True,
    disable_corrector: List[int] = [],
    solver_p: SchedulerMixin = None,
    timestep_spacing: str = "linspace",
    steps_offset: int = 0,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
):
    """
    Build the training sigma schedule and reset all per-run solver state.

    Args (only those this constructor reads directly are described in detail;
    the remaining arguments are not used here and are presumably picked up
    from `self.config` by other methods):
        num_train_timesteps (`int`):
            Length of the training schedule. Sigmas run linearly from
            `1 - 1/num_train_timesteps` down to `0` before shifting.
        solver_order (`int`):
            Size of the model-output / timestep history buffers.
        shift (`float`):
            Static shift applied as `shift * s / (1 + (shift - 1) * s)` when
            `use_dynamic_shifting` is `False`.
        use_dynamic_shifting (`bool`):
            When `True` the static shift is skipped here.
        predict_x0 (`bool`):
            Stored on the instance as `self.predict_x0`.
        solver_type (`str`):
            `"bh1"` or `"bh2"`. `"midpoint"`, `"heun"` and `"logrho"` are
            silently mapped to `"bh2"`.
        disable_corrector (`list` of `int`):
            Stored on the instance as a private copy of the given list.
        solver_p (`SchedulerMixin`, *optional*):
            Stored on the instance as `self.solver_p`.

    Raises:
        NotImplementedError: If `solver_type` is not one of the accepted values.
    """
    if solver_type not in ["bh1", "bh2"]:
        if solver_type in ["midpoint", "heun", "logrho"]:
            self.register_to_config(solver_type="bh2")
        else:
            raise NotImplementedError(
                f"{solver_type} is not implemented for {self.__class__}")

    self.predict_x0 = predict_x0
    # setable values
    self.num_inference_steps = None

    # Linear training schedule: alphas ascend from 1/N to 1, sigmas descend to 0.
    alphas = np.linspace(1, 1 / num_train_timesteps,
                         num_train_timesteps)[::-1].copy()
    sigmas = 1.0 - alphas
    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)

    if not use_dynamic_shifting:
        # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
        sigmas = shift * sigmas / (1 +
                                   (shift - 1) * sigmas)  # pyright: ignore

    # Kept on the CPU to avoid too much CPU/GPU communication.
    self.sigmas = sigmas.to("cpu")
    self.timesteps = sigmas * num_train_timesteps

    self.model_outputs = [None] * solver_order
    self.timestep_list = [None] * solver_order
    self.lower_order_nums = 0
    # The signature keeps the `[]` default so the registered config is
    # unchanged, but the instance stores its own copy: otherwise every
    # scheduler built with the default would alias one shared list object.
    self.disable_corrector = (
        list(disable_corrector) if disable_corrector is not None else []
    )
    self.solver_p = solver_p
    self.last_sample = None
    self._step_index = None
    self._begin_index = None

    self.sigma_min = self.sigmas[-1].item()
    self.sigma_max = self.sigmas[0].item()
|
| 133 |
+
|
| 134 |
+
@property
def step_index(self):
    """
    Current position in the timestep schedule (`self._step_index`); it is advanced by one after
    every scheduler step.
    """
    return self._step_index
|
| 140 |
+
|
| 141 |
+
@property
def begin_index(self):
    """
    Index of the first timestep to run (`self._begin_index`). Pipelines are expected to set it
    through `set_begin_index`.
    """
    return self._begin_index
|
| 147 |
+
|
| 148 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
| 149 |
+
def set_begin_index(self, begin_index: int = 0):
    """
    Record the index the scheduler should start from. Intended to be called by the pipeline
    before inference begins.

    Args:
        begin_index (`int`):
            Index of the first timestep to run; defaults to `0`.
    """
    self._begin_index = begin_index
|
| 158 |
+
|
| 159 |
+
# Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps
|
| 160 |
+
def set_timesteps(
    self,
    num_inference_steps: Union[int, None] = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[Union[float, None]] = None,
    shift: Optional[Union[float, None]] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`, *optional*):
            Number of sampling steps. Required unless `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        sigmas (`List[float]`, *optional*):
            Custom (unshifted) sigma schedule. When omitted, a linear schedule from
            `self.sigma_max` towards `self.sigma_min` is built.
        mu (`float`, *optional*):
            Shift parameter passed to `time_shift`; required when `use_dynamic_shifting` is enabled.
        shift (`float`, *optional*):
            Static shift used when dynamic shifting is disabled; defaults to `self.config.shift`.

    Raises:
        ValueError: If `mu` is missing while dynamic shifting is enabled, if neither
            `sigmas` nor `num_inference_steps` is given, or if `final_sigmas_type` is unknown.
    """
    if self.config.use_dynamic_shifting and mu is None:
        raise ValueError(
            " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
        )

    if sigmas is None:
        if num_inference_steps is None:
            # Fail early with a clear message instead of `None + 1` below.
            raise ValueError(
                "`num_inference_steps` must be provided when `sigmas` is not given"
            )
        sigmas = np.linspace(self.sigma_max, self.sigma_min,
                             num_inference_steps + 1).copy()[:-1]  # pyright: ignore
    else:
        # A plain Python list cannot take part in the element-wise shift below.
        sigmas = np.asarray(sigmas)

    if self.config.use_dynamic_shifting:
        sigmas = self.time_shift(mu, 1.0, sigmas)  # pyright: ignore
    else:
        if shift is None:
            shift = self.config.shift
        sigmas = shift * sigmas / (1 +
                                   (shift - 1) * sigmas)  # pyright: ignore

    final_sigmas_type = self.config.final_sigmas_type
    if final_sigmas_type == "sigma_min":
        # The flow-matching schedule built in this class has no
        # `alphas_cumprod`; reading it unconditionally raised AttributeError.
        # Keep the original formula if some subclass does provide it,
        # otherwise fall back to the last sigma of the training schedule.
        alphas_cumprod = getattr(self, "alphas_cumprod", None)
        if alphas_cumprod is not None:
            sigma_last = ((1 - alphas_cumprod[0]) / alphas_cumprod[0])**0.5
        else:
            sigma_last = self.sigma_min
    elif final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
        )

    # Timesteps are derived before the terminal sigma is appended, so there is
    # one more sigma than there are timesteps.
    timesteps = sigmas * self.config.num_train_timesteps
    sigmas = np.concatenate([sigmas, [sigma_last]
                            ]).astype(np.float32)  # pyright: ignore

    # Sigmas stay on the CPU to avoid too much CPU/GPU communication.
    self.sigmas = torch.from_numpy(sigmas).to("cpu")
    self.timesteps = torch.from_numpy(timesteps).to(
        device=device, dtype=torch.int64)

    self.num_inference_steps = len(timesteps)

    # Reset per-run solver state.
    self.model_outputs = [
        None,
    ] * self.config.solver_order
    self.lower_order_nums = 0
    self.last_sample = None
    if self.solver_p:
        self.solver_p.set_timesteps(self.num_inference_steps, device=device)

    # add an index counter for schedulers that allow duplicated timesteps
    self._step_index = None
    self._begin_index = None
|
| 228 |
+
|
| 229 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
| 230 |
+
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
    """
    Apply dynamic thresholding (https://arxiv.org/abs/2205.11487) to a predicted sample.

    For each batch element, `s` is the `dynamic_thresholding_ratio` quantile of the absolute
    values, limited to the range `[1, sample_max_value]`. The element is then clipped to
    `[-s, s]` and divided by `s`. The result has the shape and dtype of the input.
    """
    input_dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    # Upcast for the quantile calculation; clamp is also not implemented for cpu half.
    if input_dtype not in (torch.float32, torch.float64):
        sample = sample.float()

    # One row per batch element so the quantile is taken per sample.
    flat = sample.reshape(batch_size, channels * np.prod(remaining_dims))

    threshold = torch.quantile(
        flat.abs(), self.config.dynamic_thresholding_ratio, dim=1)
    # With min=1 this is equivalent to standard clipping to [-1, 1] for small values.
    threshold = torch.clamp(
        threshold, min=1, max=self.config.sample_max_value)
    # Shape (batch_size, 1) so the bounds broadcast across each row.
    threshold = threshold.unsqueeze(1)

    clipped = torch.clamp(flat, -threshold, threshold) / threshold

    restored = clipped.reshape(batch_size, channels, *remaining_dims)
    return restored.to(input_dtype)
|
| 267 |
+
|
| 268 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t
|
| 269 |
+
def _sigma_to_t(self, sigma):
    """Convert a sigma to a timestep by scaling it with `config.num_train_timesteps`."""
    num_train_timesteps = self.config.num_train_timesteps
    return sigma * num_train_timesteps
|
| 271 |
+
|
| 272 |
+
def _sigma_to_alpha_sigma_t(self, sigma):
    """Return the pair `(alpha_t, sigma_t)` for a sigma, with `alpha_t = 1 - sigma` and `sigma_t = sigma`."""
    alpha_t = 1 - sigma
    return alpha_t, sigma
|
| 274 |
+
|
| 275 |
+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
|
| 276 |
+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """
    Shift `t` with the formula `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`.

    `t` may be a scalar or any array-like supporting element-wise arithmetic.
    """
    exp_mu = math.exp(mu)
    odds = (1 / t - 1)**sigma
    return exp_mu / (exp_mu + odds)
|
| 278 |
+
|
| 279 |
+
def convert_model_output(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor:
    r"""
    Convert the model output to the corresponding type the UniPC algorithm needs.

    The scheduler works with `alpha_t = 1 - sigma_t`, and the model output `v` is the flow,
    so with `x_t = alpha_t * x_0 + sigma_t * eps`:

        x_0 = x_t - sigma_t * v
        eps = x_t + alpha_t * v   (equivalently `v + x_0`)

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        timestep (`int`, deprecated):
            Ignored; the current sigma is taken from the internal `self.step_index`.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process. May also be passed
            as the second positional argument after `timestep`.

    Returns:
        `torch.Tensor`:
            The predicted `x_0` when `self.predict_x0` is true, otherwise the predicted noise.
            Dynamic thresholding is applied to the `x_0` estimate when `config.thresholding` is set.

    Raises:
        ValueError: If `sample` is missing or `config.prediction_type` is not `"flow_prediction"`.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError(
                "missing `sample` as a required keyward argument")
    if timestep is not None:
        deprecate(
            "timesteps",
            "1.0.0",
            "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    if self.config.prediction_type != "flow_prediction":
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`,"
            " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler."
        )

    sigma = self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

    if self.predict_x0 or self.config.thresholding:
        x0_pred = sample - sigma_t * model_output
        if self.config.thresholding:
            x0_pred = self._threshold_sample(x0_pred)
        if self.predict_x0:
            return x0_pred
        # Noise prediction rebuilt from the thresholded x_0 estimate.
        return model_output + x0_pred

    # Noise prediction without thresholding. The previous expression,
    # `sample - (1 - sigma_t) * model_output`, had the wrong sign: it
    # disagreed with the x_0 formula above and with the thresholded path
    # (`model_output + x0_pred`), which expands to this same expression.
    return sample + alpha_t * model_output
|
| 349 |
+
|
| 350 |
+
def multistep_uni_p_bh_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    order: int = None,  # pyright: ignore
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.

    The step goes from `self.sigmas[self.step_index]` to
    `self.sigmas[self.step_index + 1]`, using the converted model outputs
    stored in `self.model_outputs`.

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model at the current timestep.
            Only used when `self.solver_p` is set; otherwise the stored
            `self.model_outputs` are used.
        prev_timestep (`int`):
            Deprecated. Read from the first extra positional argument or the
            `prev_timestep` keyword; a non-None value only triggers `deprecate`.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
            May also be given as the second extra positional argument.
        order (`int`):
            The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
            May also be given as the third extra positional argument.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.

    Raises:
        ValueError: If `sample` or `order` is supplied neither by keyword nor positionally.
        NotImplementedError: If `self.config.solver_type` is neither "bh1" nor "bh2".
    """
    # Legacy calling convention: (model_output, prev_timestep, sample, order).
    prev_timestep = args[0] if len(args) > 0 else kwargs.pop(
        "prev_timestep", None)
    if sample is None:
        if len(args) > 1:
            sample = args[1]
        else:
            raise ValueError(
                " missing `sample` as a required keyward argument")
    if order is None:
        if len(args) > 2:
            order = args[2]
        else:
            raise ValueError(
                " missing `order` as a required keyward argument")
    if prev_timestep is not None:
        deprecate(
            "prev_timestep",
            "1.0.0",
            "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )
    model_output_list = self.model_outputs

    # Most recent stored timestep / converted model output.
    s0 = self.timestep_list[-1]
    m0 = model_output_list[-1]
    x = sample

    # An external predictor, when configured, replaces the whole UniP update.
    if self.solver_p:
        x_t = self.solver_p.step(model_output, s0, x).prev_sample
        return x_t

    # sigma_t: next index (target of this step); sigma_s0: current index.
    sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[
        self.step_index]  # pyright: ignore
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # lambda = log(alpha) - log(sigma) at the target and at the current point.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    # Step size measured in lambda.
    h = lambda_t - lambda_s0
    device = sample.device

    # For each earlier stored output: its lambda offset relative to h (rk)
    # and the difference to m0 scaled by 1 / rk.
    rks = []
    D1s = []
    for i in range(1, order):
        si = self.step_index - i  # pyright: ignore
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)  # pyright: ignore

    rks.append(1.0)
    rks = torch.tensor(rks, device=device)

    # R (rows are powers of rks) and b form the linear system solved for the
    # predictor coefficients below.
    R = []
    b = []

    # The sign of the step flips when the model outputs are x0 predictions.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=device)

    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        # for order 2, we use a simplified version
        if order == 2:
            rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            # Drop the last row/column (the entry appended as rk = 1.0).
            rhos_p = torch.linalg.solve(R[:-1, :-1],
                                        b[:-1]).to(device).to(x.dtype)
    else:
        # order == 1: no history terms, only the first-order part is applied.
        D1s = None

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
                                    D1s)  # pyright: ignore
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p,
                                    D1s)  # pyright: ignore
        else:
            pred_res = 0
        x_t = x_t_ - sigma_t * B_h * pred_res

    # Return in the dtype of the incoming sample.
    x_t = x_t.to(x.dtype)
    return x_t
|
| 485 |
+
|
| 486 |
+
def multistep_uni_c_bh_update(
    self,
    this_model_output: torch.Tensor,
    *args,
    last_sample: torch.Tensor = None,
    this_sample: torch.Tensor = None,
    order: int = None,  # pyright: ignore
    **kwargs,
) -> torch.Tensor:
    """
    One step for the UniC (B(h) version).

    Recomputes the sample at `self.sigmas[self.step_index]` starting from
    `last_sample` at `self.sigmas[self.step_index - 1]`, using the stored
    outputs in `self.model_outputs` plus the new `this_model_output`.

    Args:
        this_model_output (`torch.Tensor`):
            The model outputs at `x_t`.
        this_timestep (`int`):
            Deprecated. Read from the first extra positional argument or the
            `this_timestep` keyword; a non-None value only triggers `deprecate`.
        last_sample (`torch.Tensor`):
            The generated sample before the last predictor `x_{t-1}`.
            May also be given as the second extra positional argument.
        this_sample (`torch.Tensor`):
            The generated sample after the last predictor `x_{t}`. Only its
            device is read here. May also be given as the third extra positional argument.
        order (`int`):
            The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
            May also be given as the fourth extra positional argument.

    Returns:
        `torch.Tensor`:
            The corrected sample tensor at the current timestep.

    Raises:
        ValueError: If `last_sample`, `this_sample` or `order` is supplied neither
            by keyword nor positionally.
        NotImplementedError: If `self.config.solver_type` is neither "bh1" nor "bh2".
    """
    # Legacy calling convention:
    # (this_model_output, this_timestep, last_sample, this_sample, order).
    this_timestep = args[0] if len(args) > 0 else kwargs.pop(
        "this_timestep", None)
    if last_sample is None:
        if len(args) > 1:
            last_sample = args[1]
        else:
            raise ValueError(
                " missing`last_sample` as a required keyward argument")
    if this_sample is None:
        if len(args) > 2:
            this_sample = args[2]
        else:
            raise ValueError(
                " missing`this_sample` as a required keyward argument")
    if order is None:
        if len(args) > 3:
            order = args[3]
        else:
            raise ValueError(
                " missing`order` as a required keyward argument")
    if this_timestep is not None:
        deprecate(
            "this_timestep",
            "1.0.0",
            "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
        )

    model_output_list = self.model_outputs

    # m0: most recent stored output; x: sample the correction starts from.
    m0 = model_output_list[-1]
    x = last_sample
    x_t = this_sample
    model_t = this_model_output

    # sigma_t: current index; sigma_s0: previous index.
    sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[
        self.step_index - 1]  # pyright: ignore
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

    # lambda = log(alpha) - log(sigma) at the current and at the previous point.
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

    # Step size measured in lambda.
    h = lambda_t - lambda_s0
    device = this_sample.device

    # History terms; indices are shifted by one more than in the predictor
    # because s0 here is `step_index - 1`.
    rks = []
    D1s = []
    for i in range(1, order):
        si = self.step_index - (i + 1)  # pyright: ignore
        mi = model_output_list[-(i + 1)]
        alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
        lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
        rk = (lambda_si - lambda_s0) / h
        rks.append(rk)
        D1s.append((mi - m0) / rk)  # pyright: ignore

    rks.append(1.0)
    rks = torch.tensor(rks, device=device)

    # R (rows are powers of rks) and b form the linear system solved for the
    # corrector coefficients below.
    R = []
    b = []

    # The sign of the step flips when the model outputs are x0 predictions.
    hh = -h if self.predict_x0 else h
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    if self.config.solver_type == "bh1":
        B_h = hh
    elif self.config.solver_type == "bh2":
        B_h = torch.expm1(hh)
    else:
        raise NotImplementedError()

    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=device)

    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)
    else:
        D1s = None

    # for order 1, we use a simplified version
    if order == 1:
        rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
    else:
        # Unlike the predictor, the full system is solved: the last
        # coefficient weights the new model output.
        rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

    if self.predict_x0:
        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        # Difference between the new output and the last stored one.
        D1_t = model_t - m0
        x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
    else:
        x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        # Difference between the new output and the last stored one.
        D1_t = model_t - m0
        x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
    # Return in the dtype of `last_sample`.
    x_t = x_t.to(x.dtype)
    return x_t
|
| 627 |
+
|
| 628 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside a timestep schedule.

    Args:
        timestep: Value to look up (compared with `==` against the schedule).
        schedule_timesteps: Schedule to search; `self.timesteps` when omitted.

    Returns:
        The matching index as a Python number. When the value occurs more
        than once, the second occurrence is returned.
    """
    schedule = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (schedule == timestep).nonzero()

    # On the very first `step` the second match is preferred (the only match
    # when there is just one), so that no sigma is skipped when denoising
    # starts in the middle of the schedule (e.g. for image-to-image).
    chosen = matches[1] if len(matches) > 1 else matches[0]
    return chosen.item()
|
| 641 |
+
|
| 642 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
|
| 643 |
+
def _init_step_index(self, timestep):
    """
    Set `self._step_index` before the first scheduler step.

    When a begin index is configured it is used directly; otherwise the
    index is looked up from `timestep` via `index_for_timestep`.
    """
    if self.begin_index is not None:
        self._step_index = self._begin_index
        return

    # Tensor timesteps are moved to the schedule's device before the lookup.
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)
    self._step_index = self.index_for_timestep(timestep)
|
| 654 |
+
|
| 655 |
+
def step(self,
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
         return_dict: bool = True,
         generator=None) -> Union[SchedulerOutput, Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep UniPC.

    Each call optionally corrects the incoming `sample` (UniC), records the
    converted model output in the history buffers, then runs the predictor
    (UniP) and advances `self._step_index` by one.

    Args:
        model_output (`torch.Tensor`):
            The direct output from learned diffusion model.
        timestep (`int`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.
        return_dict (`bool`):
            Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
        generator:
            Not referenced in this method.

    Returns:
        [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: If `self.num_inference_steps` is None.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # The corrector needs a previous step (index > 0 and a stored
    # `last_sample`) and must not be disabled for that previous index.
    use_corrector = (
        self.step_index > 0 and
        self.step_index - 1 not in self.disable_corrector and
        self.last_sample is not None  # pyright: ignore
    )

    model_output_convert = self.convert_model_output(
        model_output, sample=sample)
    if use_corrector:
        # Uses `self.this_order` as set by the previous call to `step`.
        sample = self.multistep_uni_c_bh_update(
            this_model_output=model_output_convert,
            last_sample=self.last_sample,
            this_sample=sample,
            order=self.this_order,
        )

    # Shift the history buffers left by one and append the newest entry.
    # This must happen after the corrector, which reads the old last entry.
    for i in range(self.config.solver_order - 1):
        self.model_outputs[i] = self.model_outputs[i + 1]
        self.timestep_list[i] = self.timestep_list[i + 1]

    self.model_outputs[-1] = model_output_convert
    self.timestep_list[-1] = timestep  # pyright: ignore

    # With `lower_order_final`, cap the order by the number of remaining steps.
    if self.config.lower_order_final:
        this_order = min(self.config.solver_order,
                         len(self.timesteps) -
                         self.step_index)  # pyright: ignore
    else:
        this_order = self.config.solver_order

    self.this_order = min(this_order,
                          self.lower_order_nums + 1)  # warmup for multistep
    assert self.this_order > 0

    # Remember the (possibly corrected) sample for the next call's corrector.
    self.last_sample = sample
    prev_sample = self.multistep_uni_p_bh_update(
        model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
        sample=sample,
        order=self.this_order,
    )

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1

    # upon completion increase step index by one
    self._step_index += 1  # pyright: ignore

    if not return_dict:
        return (prev_sample,)

    return SchedulerOutput(prev_sample=prev_sample)
|
| 740 |
+
|
| 741 |
+
def scale_model_input(self, sample: torch.Tensor, *args,
                      **kwargs) -> torch.Tensor:
    """
    Return `sample` unchanged.

    Present so this scheduler can be swapped with schedulers that scale the
    denoising model input depending on the current timestep; extra
    positional and keyword arguments are accepted and ignored.

    Args:
        sample (`torch.Tensor`):
            The input sample.

    Returns:
        `torch.Tensor`:
            The same object that was passed in.
    """
    return sample
|
| 756 |
+
|
| 757 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
|
| 758 |
+
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
) -> torch.Tensor:
    """
    Mix `noise` into `original_samples` at the noise level of each timestep.

    Args:
        original_samples: Clean samples; their device and dtype are used for the sigmas.
        noise: Noise tensor combined with the samples.
        timesteps: One timestep per sample (its first dimension is the batch size).

    Returns:
        `alpha_t * original_samples + sigma_t * noise`, with `alpha_t` and
        `sigma_t` obtained from `self._sigma_to_alpha_sigma_t`.
    """
    target_device = original_samples.device
    # Sigmas follow the device and dtype of the samples.
    sigmas = self.sigmas.to(device=target_device, dtype=original_samples.dtype)

    # mps does not support float64, so floating timesteps are forced to float32 there.
    if target_device.type == "mps" and torch.is_floating_point(timesteps):
        cast = {"dtype": torch.float32}
    else:
        cast = {}
    schedule_timesteps = self.timesteps.to(target_device, **cast)
    timesteps = timesteps.to(target_device, **cast)

    batch = timesteps.shape[0] if self.begin_index is not None else None
    if batch is None:
        # begin_index is None when the scheduler is used for training or the
        # pipeline does not implement set_begin_index: look up each timestep.
        step_indices = [
            self.index_for_timestep(t, schedule_timesteps) for t in timesteps
        ]
    elif self.step_index is not None:
        # called after the first denoising step (for inpainting)
        step_indices = [self.step_index] * batch
    else:
        # called before the first denoising step to create the initial latent (img2img)
        step_indices = [self.begin_index] * batch

    # Append trailing singleton dims so sigma broadcasts against the samples.
    sigma = sigmas[step_indices].flatten()
    for _ in range(len(original_samples.shape) - len(sigma.shape)):
        sigma = sigma.unsqueeze(-1)

    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
    return alpha_t * original_samples + sigma_t * noise
|
| 798 |
+
|
| 799 |
+
def __len__(self):
    """Return `num_train_timesteps` from the scheduler config."""
    config = self.config
    return config.num_train_timesteps
|
humo/models/utils/utils.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import argparse
|
| 3 |
+
import binascii
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import json
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
|
| 9 |
+
import imageio
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision
|
| 12 |
+
from moviepy.editor import AudioFileClip, VideoClip
|
| 13 |
+
|
| 14 |
+
__all__ = ['tensor_to_video', 'prepare_json_dataset']
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def tensor_to_video(tensor, output_video_path, input_audio_path, fps=25):
    """
    Write a stack of frames to a video file and mux in an audio track.

    The output lasts as long as the shorter of the frame sequence
    (`tensor.shape[0] / fps`) and the audio file; the audio is trimmed to
    that duration.

    Args:
        tensor: Frames indexed along axis 0; `tensor[i]` is handed to moviepy
            as frame `i`. Presumably a numpy array shaped [f, h, w, c] --
            confirm against the caller.
        output_video_path (str): The file path where the output video will be saved.
        input_audio_path (str): The path to the audio file that contains the audio track to be added.
        fps (int): The frame rate of the output video. Default is 25 fps.
    """
    frame_count = tensor.shape[0]

    def make_frame(t):
        # Clamp to the last frame so a time at the very end stays in range.
        frame_index = min(int(t * fps), frame_count - 1)
        return tensor[frame_index]

    video_duration = frame_count / fps
    source_audio = AudioFileClip(input_audio_path)
    try:
        final_duration = min(video_duration, source_audio.duration)
        audio_clip = source_audio.subclip(0, final_duration)
        new_video_clip = VideoClip(make_frame, duration=final_duration)
        new_video_clip = new_video_clip.set_audio(audio_clip)
        new_video_clip.write_videofile(output_video_path, fps=fps, audio_codec="aac")
    finally:
        # Release the audio reader even when encoding fails.
        source_audio.close()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def prepare_json_dataset(json_path):
    """
    Load a JSON dataset description and return it as an OmegaConf list.

    The JSON file is a mapping from item name to a row with the keys
    `prompt`, `audio_path` and `img_paths`. Each row becomes one entry with
    the keys `text`, `ref_img`, `audio` and `itemname`.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The result of `OmegaConf.create` on the list of entries.
    """
    with open(json_path, "rb") as f:
        data = json.load(f)

    def to_sample(itemname, row):
        # Prompt cleanup: trim whitespace, underscores -> spaces, drop outer quotes.
        text = row['prompt'].strip().replace("_", " ").strip('"')
        audio_path = row['audio_path']
        ref_img_path = list(row['img_paths'])
        return {
            "text": text,
            "ref_img": ref_img_path,
            "audio": audio_path,
            "itemname": itemname
        }

    samples = [to_sample(itemname, row) for itemname, row in data.items()]
    return OmegaConf.create(samples)
|
humo/models/wan_modules/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .attention import flash_attention
|
| 2 |
+
from .model import WanModel
|
| 3 |
+
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
| 4 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 5 |
+
from .vae import WanVAE
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
'WanVAE',
|
| 9 |
+
'WanModel',
|
| 10 |
+
'T5Model',
|
| 11 |
+
'T5Encoder',
|
| 12 |
+
'T5Decoder',
|
| 13 |
+
'T5EncoderModel',
|
| 14 |
+
'HuggingfaceTokenizer',
|
| 15 |
+
'flash_attention',
|
| 16 |
+
]
|
humo/models/wan_modules/attention.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import warnings
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
import flash_attn_interface
|
| 8 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 9 |
+
except ModuleNotFoundError:
|
| 10 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import flash_attn
|
| 14 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 15 |
+
except ModuleNotFoundError:
|
| 16 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
'flash_attention',
|
| 21 |
+
'attention',
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------
|
| 26 |
+
# Custom op + fake kernel
|
| 27 |
+
# ---------------------------
|
| 28 |
+
from typing import Optional, Sequence # <- add Sequence
|
| 29 |
+
|
| 30 |
+
# ... imports unchanged ...
|
| 31 |
+
from typing import Optional, Sequence
|
| 32 |
+
|
| 33 |
+
@torch.library.custom_op("wan::flash_attention", mutates_args=())
def _wan_flash_attention_op(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_lens: Optional[torch.Tensor] = None,
    k_lens: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    q_scale: Optional[float] = None,
    causal: bool = False,
    # IMPORTANT: schema-friendly default (None), not a tuple
    window_size: Optional[Sequence[int]] = None,
    deterministic: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    version: Optional[int] = None,
) -> torch.Tensor:
    """
    Eager kernel of the `wan::flash_attention` custom op.

    Backend selection: flash-attention 3 when it is installed, unless
    `version == 2` is requested and flash-attention 2 is installed; otherwise
    flash-attention 2; otherwise `scaled_dot_product_attention`.

    Args:
        q, k, v: Attention inputs indexed as (batch, length, heads, head_dim).
        q_lens, k_lens: Optional per-sample valid lengths. Used by the
            flash-attention paths only; the SDPA fallback applies no padding mask.
        dropout_p: Dropout probability (not forwarded to flash-attention 3).
        softmax_scale: Scaling of QK^T before softmax; backend default when None.
        q_scale: Optional multiplier applied to `q` after casting.
        causal: Whether to apply a causal mask.
        window_size: Optional pair (left, right); None means (-1, -1).
            Only forwarded to flash-attention 2.
        deterministic: Forwarded to the flash-attention backends.
        dtype: Half-precision dtype that non-half inputs are cast to.
        version: Requested flash-attention major version (2 or 3), or None.

    Returns:
        The attention output cast back to the dtype of `q`.

    Raises:
        ValueError: If `window_size` is given and does not have length 2.
        RuntimeError: If flash-attention 3 returns an output of unexpected shape.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.size(-1) <= 256

    # normalize window_size to a 2-tuple for FA2 API
    if window_size is None:
        ws = (-1, -1)
    else:
        ws = tuple(window_size)
        if len(ws) != 2:
            raise ValueError(f"window_size must have length 2; got {window_size!r}")

    b, lq, nheads = q.shape[0], q.shape[1], q.shape[2]
    lk = k.shape[1]
    out_dtype = q.dtype

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn('Flash attention 3 is not available, use flash attention 2 instead.')

    # Honor an explicit request for FA2; previously `version` never
    # influenced which backend ran.
    prefer_fa2 = version == 2 and FLASH_ATTN_2_AVAILABLE
    use_fa3 = FLASH_ATTN_3_AVAILABLE and not prefer_fa2
    use_fa2 = FLASH_ATTN_2_AVAILABLE and not use_fa3

    if not (use_fa3 or use_fa2):
        # SDPA fallback. Apply q_scale and softmax_scale here too, so the
        # result matches the flash-attention paths.
        q_s = q.transpose(1, 2).to(dtype)
        k_s = k.transpose(1, 2).to(dtype)
        v_s = v.transpose(1, 2).to(dtype)
        if q_scale is not None:
            q_s = q_s * q_scale
        out = torch.nn.functional.scaled_dot_product_attention(
            q_s, k_s, v_s, attn_mask=None, is_causal=causal, dropout_p=dropout_p,
            scale=softmax_scale,
        ).transpose(1, 2).contiguous()
        return out.to(out_dtype)

    def half(x: torch.Tensor) -> torch.Tensor:
        return x if x.dtype in half_dtypes else x.to(dtype)

    # Flatten (batch, length) into one token axis, dropping padding when
    # per-sample lengths are given.
    if q_lens is None:
        q_flat = half(q.flatten(0, 1))
        q_lens = torch.tensor([lq] * b, dtype=torch.int32)
    else:
        q_flat = half(torch.cat([u[:n] for u, n in zip(q, q_lens)]))

    if k_lens is None:
        k_flat = half(k.flatten(0, 1))
        v_flat = half(v.flatten(0, 1))
        k_lens = torch.tensor([lk] * b, dtype=torch.int32)
    else:
        k_flat = half(torch.cat([u[:n] for u, n in zip(k, k_lens)]))
        v_flat = half(torch.cat([u[:n] for u, n in zip(v, k_lens)]))

    q_flat = q_flat.to(v_flat.dtype)
    k_flat = k_flat.to(v_flat.dtype)
    if q_scale is not None:
        q_flat = q_flat * q_scale

    # Cumulative sequence boundaries, shared by both flash-attention backends.
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
        0, dtype=torch.int32).to(q_flat.device, non_blocking=True)
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
        0, dtype=torch.int32).to(k_flat.device, non_blocking=True)

    if use_fa3:
        ret = flash_attn_interface.flash_attn_varlen_func(
            q=q_flat,
            k=k_flat,
            v=v_flat,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic,
        )
        # The return value may be a tensor or a tuple/list whose first item is the output.
        out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
        total_q = b * lq
        if out0.dim() != 3:
            raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()} shape={tuple(out0.shape)}")
        if out0.shape[0] == total_q:
            out_flat = out0
        elif out0.shape[0] == nheads and out0.shape[1] == total_q:
            # Heads-first layout: move the token axis to the front.
            out_flat = out0.transpose(0, 1).contiguous()
        else:
            raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")
        out = out_flat.unflatten(0, (b, lq))
    else:
        out = flash_attn.flash_attn_varlen_func(
            q=q_flat,
            k=k_flat,
            v=v_flat,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=ws,  # <- pass 2-tuple
            deterministic=deterministic,
        ).unflatten(0, (b, lq))

    return out.to(out_dtype)
|
| 142 |
+
|
| 143 |
+
@_wan_flash_attention_op.register_fake
def _wan_flash_attention_op_fake(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p: float = 0.0,
    softmax_scale=None,
    q_scale=None,
    causal: bool = False,
    window_size: Optional[Sequence[int]] = None,
    deterministic: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    version: Optional[int] = None,
):
    """
    Fake (shape-only) kernel for `_wan_flash_attention_op`.

    Returns an uninitialized tensor with the dtype of `q`, allocated from `q`
    so it stays on the same (fake) device. Its shape is `q`'s first three
    dimensions followed by the last dimension of `v`.
    """
    batch, q_len, num_heads, _ = q.shape
    out_shape = (batch, q_len, num_heads, v.shape[-1])
    return q.new_empty(out_shape, dtype=q.dtype)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# ---------------------------
|
| 167 |
+
# Public API (unchanged signature)
|
| 168 |
+
# ---------------------------
|
| 169 |
+
def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    Public attention entry point; forwards everything to the custom op.

    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B].
    k_lens:         [B].
    dropout_p:      float. Dropout probability.
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    q_scale:        float. Optional multiplier forwarded to the op.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left right). If not (-1, -1), apply sliding window local attention.
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    version:        int. Flash-attention version forwarded to the op.
    """
    # Delegating to the custom op lets Dynamo/AOT treat attention as a single
    # node; the op's eager kernel does the actual work.
    op_options = dict(
        q_lens=q_lens,
        k_lens=k_lens,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        q_scale=q_scale,
        causal=causal,
        window_size=window_size,
        deterministic=deterministic,
        dtype=dtype,
        version=version,
    )
    return _wan_flash_attention_op(q, k, v, **op_options)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    """
    Attention with a flash-attention fast path and an SDPA fallback.

    When flash-attention 2 or 3 is installed, all arguments are forwarded to
    `flash_attention` (`fa_version` is passed as `version`). Otherwise
    `torch.nn.functional.scaled_dot_product_attention` is used: `q_lens`,
    `k_lens`, `window_size` and `deterministic` are then not applied (a
    warning is emitted when lengths were given) and the result is returned
    in `dtype`.

    Args:
        q, k, v: Attention inputs indexed as (batch, length, heads, head_dim).
        q_lens, k_lens: Optional per-sample valid lengths.
        dropout_p: Dropout probability.
        softmax_scale: Scaling of QK^T before softmax; backend default when None.
        q_scale: Optional multiplier applied to `q`.
        causal: Whether to apply a causal mask.
        window_size: (left, right) sliding window, flash-attention path only.
        deterministic: Flash-attention path only.
        dtype: Dtype inputs are cast to in the fallback (and, for non-half
            inputs, in the flash-attention path).
        fa_version: Requested flash-attention version.

    Returns:
        The attention output in the same (batch, length, heads, head_dim) layout.
    """
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )

    if q_lens is not None or k_lens is not None:
        warnings.warn(
            'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
        )
    q_ = q.transpose(1, 2).to(dtype)
    k_ = k.transpose(1, 2).to(dtype)
    v_ = v.transpose(1, 2).to(dtype)
    # Apply the same query scaling and softmax scale as the flash-attention
    # path; these were previously dropped in the fallback.
    if q_scale is not None:
        q_ = q_ * q_scale
    out = torch.nn.functional.scaled_dot_product_attention(
        q_, k_, v_, attn_mask=None, is_causal=causal, dropout_p=dropout_p,
        scale=softmax_scale,
    )
    return out.transpose(1, 2).contiguous()
|
humo/models/wan_modules/clip.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
|
| 11 |
+
from .attention import flash_attention
|
| 12 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 13 |
+
from .xlm_roberta import XLMRoberta
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
'XLMRobertaCLIP',
|
| 17 |
+
'clip_xlm_roberta_vit_h_14',
|
| 18 |
+
'CLIPModel',
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def pos_interpolate(pos, seq_len):
    """Resize a positional-embedding table to ``seq_len`` tokens.

    ``pos`` is indexed as [1, tokens, dim]. If it already has ``seq_len``
    tokens it is returned as is. Otherwise the trailing square grid of
    embeddings is bicubically resampled to ``int(sqrt(seq_len))`` per side,
    while any leading tokens that do not belong to the square grid are kept
    untouched and prepended to the result.
    """
    total = pos.size(1)
    if total == seq_len:
        return pos

    src_side = int(math.sqrt(total))
    dst_side = int(math.sqrt(seq_len))
    # tokens in front of the square grid are copied, not interpolated
    extra = total - src_side * src_side
    prefix, grid = pos[:, :extra], pos[:, extra:]

    # [1, S*S, D] -> [1, D, S, S] so that F.interpolate sees a 2D image
    grid = grid.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid,
        size=(dst_side, dst_side),
        mode='bicubic',
        align_corners=False)
    # back to [1, T*T, D]
    grid = grid.flatten(2).transpose(1, 2)
    return torch.cat([prefix, grid], dim=1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class QuickGELU(nn.Module):
    """Sigmoid-gated activation computing ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that normalizes in float32 and casts back to the input type."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type_as(x)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The attention itself is delegated to ``flash_attention`` (called with
    ``version=2``); this module only owns the input and output projections.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # projections
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x:  [B, L, C].
        """
        batch, seq, channels = x.size()
        heads, head_dim = self.num_heads, self.head_dim

        # one projection, then split into query / key / value along axis 2
        qkv = self.to_qkv(x).view(batch, seq, 3, heads, head_dim)
        q, k, v = qkv.unbind(2)

        # attention dropout is only active in training mode
        drop = self.attn_dropout if self.training else 0.0
        out = flash_attention(
            q, k, v, dropout_p=drop, causal=self.causal, version=2)
        out = out.reshape(batch, seq, channels)

        # output projection followed by dropout
        out = self.proj(out)
        return F.dropout(out, self.proj_dropout, self.training)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class SwiGLU(nn.Module):
    """Gated MLP: ``fc3(silu(fc1(x)) * fc2(x))``."""

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # fc1 feeds the SiLU gate, fc2 the value branch, fc3 projects back to dim
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        gate = F.silu(self.fc1(x))
        value = self.fc2(x)
        return self.fc3(gate * value)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class AttentionBlock(nn.Module):
    """Transformer block: self-attention plus MLP, each with a residual connection.

    With ``post_norm=True`` the norms are applied to the branch outputs
    (``x + norm(f(x))``); otherwise to the branch inputs (``x + f(norm(x))``).
    """

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        hidden = int(dim * mlp_ratio)

        # attention branch
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)

        # feed-forward branch
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, hidden)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, hidden),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(hidden, dim),
                nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            return x + self.norm2(self.mlp(x))
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class AttentionPool(nn.Module):
    """Pool a token sequence into one vector using a learned query.

    A single learned ``cls_embedding`` is projected to the query; keys and
    values come from the input tokens. The attended result goes through an
    output projection, dropout and a residual MLP.
    """

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # learned query, scaled by 1/sqrt(dim)
        init_scale = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(init_scale * torch.randn(1, 1, dim))

        # projections
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        # residual MLP
        hidden = int(dim * mlp_ratio)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x:  [B, L, C].
        """
        batch, seq, channels = x.size()
        heads, head_dim = self.num_heads, self.head_dim

        # the single learned query is shared (expanded) across the batch
        query = self.to_q(self.cls_embedding).view(1, 1, heads, head_dim)
        query = query.expand(batch, -1, -1, -1)
        key, value = self.to_kv(x).view(batch, seq, 2, heads,
                                        head_dim).unbind(2)

        # attend: one query token over all input tokens
        pooled = flash_attention(query, key, value, version=2)
        pooled = pooled.reshape(batch, 1, channels)

        # output projection + dropout
        pooled = self.proj(pooled)
        pooled = F.dropout(pooled, self.proj_dropout, self.training)

        # residual MLP, then drop the length-1 token axis
        pooled = pooled + self.mlp(self.norm(pooled))
        return pooled[:, 0]
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class VisionTransformer(nn.Module):
    """ViT image encoder: patch embedding, positional embedding and a stack of
    ``AttentionBlock`` layers.

    ``forward`` returns the token sequence produced by the transformer stack;
    the ``post_norm`` layer and ``head`` created in ``__init__`` are not
    applied by ``forward`` in this file.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        # only a warning: num_patches below uses floor division
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        # a falsy out_dim (None / 0) falls back to the model width
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        # NOTE: this boolean is overwritten further down by a LayerNorm module
        # stored under the same attribute name.
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        # the conv bias is disabled when a pre-norm layer follows
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        # a class token exists only for the token-based pooling modes
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        # one position per patch, plus one for the class token when present
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        # blocks receive the boolean post_norm argument and causal=False
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        # replaces the boolean stored in self.post_norm above
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        """Encode images into a token sequence.

        Args:
            x: Image batch fed to the 3-input-channel patch convolution.
            interpolation: If True, resize the positional embedding to the
                current token count with ``pos_interpolate``.
            use_31_block: If True, run every transformer block except the last
                one; otherwise run all blocks.

        Returns:
            The transformer token sequence (no post-norm, no head).
        """
        b = x.size(0)

        # embeddings
        # conv output [B, dim, h, w] -> [B, h*w, dim]
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            # skip the final block
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class XLMRobertaWithHead(XLMRoberta):
    """``XLMRoberta`` encoder followed by masked mean pooling and a 2-layer head.

    Requires an ``out_dim`` keyword; all other keywords go to ``XLMRoberta``.
    """

    def __init__(self, **kwargs):
        # out_dim is consumed here; the base class never sees it
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # projection head: dim -> mid -> out_dim, both layers without bias
        hidden = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, self.out_dim, bias=False))

    def forward(self, ids):
        # encoder output from the base class
        hidden_states = super().forward(ids)

        # mean over non-padding positions only
        keep = ids.ne(self.pad_id).unsqueeze(-1).to(hidden_states)
        pooled = (hidden_states * keep).sum(dim=1) / keep.sum(dim=1)

        return self.head(pooled)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
class XLMRobertaCLIP(nn.Module):
    """CLIP-style pair of a ``VisionTransformer`` and an ``XLMRobertaWithHead``.

    Holds the two towers plus a learnable ``log_scale`` initialised to
    ``log(1 / 0.07)``.
    """

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()

        # shared / vision configuration
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation

        # text configuration
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # image tower
        vision_cfg = dict(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.visual = VisionTransformer(**vision_cfg)

        # text tower
        text_cfg = dict(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        self.textual = XLMRobertaWithHead(**text_cfg)

        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs:       [B, 3, H, W] of torch.float32.
        - mean:     [0.48145466, 0.4578275, 0.40821073]
        - std:      [0.26862954, 0.26130258, 0.27577711]
        txt_ids:    [B, L] of torch.long.
                    Encoded by data.CLIPTokenizer.
        """
        image_out = self.visual(imgs)
        text_out = self.textual(txt_ids)
        return image_out, text_out

    def param_groups(self):
        """Split parameters into a no-weight-decay group and a default group.

        Parameters whose name contains ``'norm'`` or ends with ``'bias'`` go
        into the first group (``weight_decay=0.0``); everything else goes into
        the second group, which leaves weight decay to the optimizer default.
        """
        no_decay, decay = [], []
        for name, param in self.named_parameters():
            if 'norm' in name or name.endswith('bias'):
                no_decay.append(param)
            else:
                decay.append(param)
        return [{
            'params': no_decay,
            'weight_decay': 0.0
        }, {
            'params': decay
        }]
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    """Build a CLIP model and, optionally, its image preprocessing pipeline.

    Args:
        pretrained: Accepted for interface compatibility; not used here.
        pretrained_name: Only used to pick normalization statistics: names
            containing ``'siglip'`` (case-insensitive) get mean/std of 0.5,
            everything else (including ``None``) gets the CLIP statistics.
        model_cls: Class instantiated with ``**kwargs``.
        return_transforms: If True, also return a torchvision transform
            (bicubic resize to ``model.image_size``, ``ToTensor``, ``Normalize``).
        return_tokenizer: Accepted for interface compatibility; not used here.
        tokenizer_padding: Accepted for interface compatibility; not used here.
        dtype: dtype the model is moved to.
        device: Device the model is created on and moved to.

    Returns:
        The model alone, or a ``(model, transforms)`` tuple when
        ``return_transforms`` is True.
    """
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std; guard against the default pretrained_name=None, which
        # previously crashed on ``None.lower()``
        if pretrained_name is not None and 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)
    return output[0] if len(output) == 1 else output
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def clip_xlm_roberta_vit_h_14(
        pretrained=False,
        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
        **kwargs):
    """Build an ``XLMRobertaCLIP`` with the preset configuration below.

    Any keyword argument overrides the matching preset value; unknown keywords
    are passed through to ``_clip``.
    """
    defaults = {
        'embed_dim': 1024,
        'image_size': 224,
        'patch_size': 14,
        'vision_dim': 1280,
        'vision_mlp_ratio': 4,
        'vision_heads': 16,
        'vision_layers': 32,
        'vision_pool': 'token',
        'activation': 'gelu',
        'vocab_size': 250002,
        'max_text_len': 514,
        'type_size': 1,
        'pad_id': 1,
        'text_dim': 1024,
        'text_heads': 16,
        'text_layers': 24,
        'text_post_norm': True,
        'text_dropout': 0.1,
        'attn_dropout': 0.0,
        'proj_dropout': 0.0,
        'embedding_dropout': 0.0,
    }
    # caller-supplied values win over the presets
    cfg = {**defaults, **kwargs}
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class CLIPModel:
    """Inference wrapper around ``clip_xlm_roberta_vit_h_14``.

    Builds the model in eval mode with gradients disabled, loads weights from
    ``checkpoint_path`` and creates a ``HuggingfaceTokenizer``.
    """

    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # build the network together with its preprocessing transforms
        model, transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device)
        self.transforms = transforms
        self.model = model.eval().requires_grad_(False)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True if supported).
        logging.info(f'loading {checkpoint_path}')
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.model.load_state_dict(state_dict)

        # tokenizer; sequence length is the model's max_text_len minus 2
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path,
            seq_len=self.model.max_text_len - 2,
            clean='whitespace')

    def visual(self, videos):
        """Encode videos with the vision tower, skipping its last block.

        Each element of ``videos`` has its first two axes swapped and is
        bicubically resized to the model's square input size; the results are
        concatenated along axis 0, mapped with ``x * 0.5 + 0.5`` (presumably
        from [-1, 1] to [0, 1] - confirm against the caller) and normalized
        with the last transform of ``self.transforms``.
        """
        side = self.model.image_size
        frames = [
            F.interpolate(
                clip.transpose(0, 1),
                size=(side, side),
                mode='bicubic',
                align_corners=False) for clip in videos
        ]
        batch = torch.cat(frames)

        # the final entry of the Compose pipeline is the Normalize transform
        normalize = self.transforms.transforms[-1]
        batch = normalize(batch.mul_(0.5).add_(0.5))

        # forward under CUDA autocast with the configured dtype
        with torch.amp.autocast('cuda', dtype=self.dtype):
            features = self.model.visual(batch, use_31_block=True)
            return features
|
humo/models/wan_modules/model.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.cuda.amp as amp
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 9 |
+
|
| 10 |
+
from .attention import flash_attention
|
| 11 |
+
|
| 12 |
+
__all__ = ['WanModel']
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def sinusoidal_embedding_1d(dim, position):
    """Return float64 sinusoidal embeddings for a 1-D tensor of positions.

    The result has shape [len(position), dim]: the first ``dim // 2`` columns
    are cosines and the last ``dim // 2`` columns are sines of
    ``position * 10000 ** (-i / (dim // 2))``. ``dim`` must be even.
    """
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # per-column frequency, on the same dtype/device as the positions
    exponents = -torch.arange(half).to(position).div(half)
    inv_freq = torch.pow(10000, exponents)

    angles = torch.outer(position, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute rotary-embedding factors as unit complex numbers.

    Returns a complex tensor of shape [max_seq_len, dim // 2] whose entry
    (p, i) is ``exp(1j * p / theta ** (2 * i / dim))``. ``dim`` must be even.
    CUDA autocast is disabled for the computation.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponents)
    angles = torch.outer(torch.arange(max_seq_len), inv_freq)
    # magnitude 1, phase = angle
    return torch.polar(torch.ones_like(angles), angles)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@torch.amp.autocast("cuda", enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """Apply 3-axis rotary position embedding to each sample of ``x``.

    ``x`` is indexed as [B, L, heads, head_dim]. ``grid_sizes`` holds one
    (F, H, W) triple per sample. ``freqs`` is a complex table whose columns
    are split into three groups, used for the F, H and W axes respectively.
    For every sample only the first ``F * H * W`` tokens are rotated; the
    remaining tokens are passed through unchanged. The rotation is done in
    float64 and the stacked result is returned as float32.
    """
    num_heads, half_dim = x.size(2), x.size(3) // 2

    # column budget per axis: the first axis takes whatever the other two leave
    third = half_dim // 3
    freqs_f, freqs_h, freqs_w = freqs.split(
        [half_dim - 2 * third, third, third], dim=1)

    rotated = []
    for idx, (f, h, w) in enumerate(grid_sizes.tolist()):
        n_tokens = f * h * w

        # view consecutive channel pairs of the gridded tokens as complex numbers
        tokens = x[idx, :n_tokens].to(torch.float64).reshape(
            n_tokens, num_heads, -1, 2)
        tokens = torch.view_as_complex(tokens)

        # per-token phase: broadcast each axis table over the (f, h, w) grid
        phase = torch.cat([
            freqs_f[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs_h[:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs_w[:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                          dim=-1).reshape(n_tokens, 1, -1)

        # rotate, then re-attach the tokens beyond the grid untouched
        tokens = torch.view_as_real(tokens * phase).flatten(2)
        rotated.append(torch.cat([tokens, x[idx, n_tokens:]]))
    return torch.stack(rotated).float()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class WanRMSNorm(nn.Module):
    """RMS normalization over the last axis with a learnable per-channel scale."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # normalize in float32, cast back to the input type, then scale
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

    def _norm(self, x):
        # divide by sqrt(mean(x^2) + eps) along the last axis
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class WanLayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` (no affine parameters by default) computed in float32."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # upcast for the normalization, then restore the input's tensor type
        out = super().forward(x.float())
        return out.type_as(x)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class WanSelfAttention(nn.Module):
    """Self-attention with separate q/k/v/o projections, optional RMS
    normalization of queries and keys, and rotary position embedding."""

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        # query/key normalization (identity when qk_norm is disabled)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]; it is fed to the ``Linear(dim, dim)``
                projections and split into heads afterwards
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        batch, seq = x.shape[:2]
        heads, head_dim = self.num_heads, self.head_dim

        # project, normalize (q/k only) and split into heads
        query = self.norm_q(self.q(x)).view(batch, seq, heads, head_dim)
        key = self.norm_k(self.k(x)).view(batch, seq, heads, head_dim)
        value = self.v(x).view(batch, seq, heads, head_dim)

        # rotary embedding goes on queries and keys only
        attended = flash_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)

        # merge heads and project out
        return self.o(attended.flatten(2))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class WanT2VCrossAttention(WanSelfAttention):
    """Cross-attention: queries from ``x``, keys and values from ``context``.

    Reuses the projections and q/k norms of ``WanSelfAttention``; no rotary
    embedding is applied here.
    """

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        # queries come from x, keys/values from the context
        query = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)
        key = self.norm_k(self.k(context)).view(batch, -1, heads, head_dim)
        value = self.v(context).view(batch, -1, heads, head_dim)

        attended = flash_attention(query, key, value, k_lens=context_lens)

        # merge heads and project out
        return self.o(attended.flatten(2))
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class WanI2VCrossAttention(WanSelfAttention):
    r"""
    Cross-attention with two key/value streams: the first 257 context tokens
    go through dedicated ``k_img``/``v_img`` projections, the remaining tokens
    through the inherited ``k``/``v`` projections. Both attention results are
    summed before the shared output projection.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

        # extra projections for the leading (image) part of the context
        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        # self.alpha = nn.Parameter(torch.zeros((1, )))
        if qk_norm:
            self.norm_k_img = WanRMSNorm(dim, eps=eps)
        else:
            self.norm_k_img = nn.Identity()

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]; the first 257 positions are
                treated as image tokens, the rest as the remaining context
            context_lens(Tensor): Shape [B]

        Returns:
            Tensor: Shape [B, L1, C]
        """
        image_tokens = 257
        context_img = context[:, :image_tokens]
        context_rest = context[:, image_tokens:]

        batch = x.size(0)
        heads, width = self.num_heads, self.head_dim

        # compute query, key, value for both streams
        query = self.norm_q(self.q(x)).view(batch, -1, heads, width)
        key = self.norm_k(self.k(context_rest)).view(batch, -1, heads, width)
        value = self.v(context_rest).view(batch, -1, heads, width)
        key_img = self.norm_k_img(self.k_img(context_img)).view(
            batch, -1, heads, width)
        value_img = self.v_img(context_img).view(batch, -1, heads, width)

        # image stream is attended without key lengths, the other with them
        attended_img = flash_attention(query, key_img, value_img, k_lens=None)
        attended = flash_attention(query, key, value, k_lens=context_lens)

        # merge heads, sum both streams, project out
        merged = attended.flatten(2) + attended_img.flatten(2)
        return self.o(merged)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# Registry of cross-attention implementations, keyed by the `cross_attn_type`
# string that WanAttentionBlock uses to pick its cross-attention layer.
WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class WanAttentionBlock(nn.Module):
    r"""
    One transformer block: modulated self-attention, cross-attention over the
    context, and a modulated feed-forward network, each added residually.
    """

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()

        # hyper-parameters kept for reference
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # self-attention branch
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size,
                                          qk_norm, eps)

        # cross-attention branch (pre-norm only when cross_attn_norm is set)
        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()
        cross_attn_cls = WAN_CROSSATTENTION_CLASSES[cross_attn_type]
        self.cross_attn = cross_attn_cls(dim, num_heads, (-1, -1), qk_norm,
                                         eps)

        # feed-forward branch
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # learned offset added to the six modulation vectors
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, 6, C], must be float32
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            context(Tensor): conditioning tokens passed to the cross-attention
            context_lens(Tensor): passed to the cross-attention as key lengths
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            mods = (self.modulation + e).chunk(6, dim=1)
        assert mods[0].dtype == torch.float32
        sa_shift, sa_scale, sa_gate, ffn_shift, ffn_scale, ffn_gate = mods

        # self-attention: shift/scale the normed input, gate the residual
        attn_in = self.norm1(x).float() * (1 + sa_scale) + sa_shift
        attn_out = self.self_attn(attn_in, seq_lens, grid_sizes, freqs)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + attn_out * sa_gate

        # cross-attention: plain residual, no modulation
        x = x + self.cross_attn(self.norm3(x), context, context_lens)

        # feed-forward: shift/scale the normed input, gate the residual
        ffn_out = self.ffn(self.norm2(x).float() * (1 + ffn_scale) + ffn_shift)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + ffn_out * ffn_gate
        return x
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class Head(nn.Module):
    r"""
    Output head: modulated layer norm followed by a linear projection to
    ``prod(patch_size) * out_dim`` channels per token.
    """

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers: every token is expanded to one full patch of output channels
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # learned offset added to the (shift, scale) modulation pair
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C], must be float32
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            shift, scale = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            out = self.head(self.norm(x) * (1 + scale) + shift)
        return out
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class MLPProj(torch.nn.Module):
    r"""
    Two-layer MLP projector: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm,
    mapping features of size ``in_dim`` to size ``out_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        # layer order is part of the state-dict layout (proj.0 ... proj.4)
        layers = [
            torch.nn.LayerNorm(in_dim),
            torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(),
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, image_embeds):
        r"""
        Args:
            image_embeds(Tensor): features whose last dimension is ``in_dim``

        Returns:
            Tensor: same leading dimensions, last dimension ``out_dim``
        """
        return self.proj(image_embeds)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
class WanModel(ModelMixin, ConfigMixin):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.

    The model patchifies each input video with a strided 3-D convolution,
    runs the resulting token sequence through ``num_layers`` WanAttentionBlock
    layers conditioned on timestep and context embeddings, and maps the
    tokens back to video tensors with ``Head`` + ``unpatchify``.
    """

    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=5120,
                 ffn_dim=13824,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=40,
                 num_layers=40,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 5120):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 13824):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 40):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 40):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
        """

        super().__init__()

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        # kernel == stride, so the convolution cuts the video into
        # non-overlapping patches and embeds each one as a `dim` vector
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        # produces the six per-block modulation vectors (unflattened to
        # [B, 6, dim] in forward)
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # three rope tables concatenated along the channel axis; rope_apply
        # splits them again and indexes them by frame, row and column
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        if model_type == 'i2v':
            # projects 1280-dim image features to the model width
            self.img_emb = MLPProj(1280, dim)

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        clip_fea=None,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v':
            assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        freqs = self.freqs.to(device)

        if y is not None:
            # stack the conditioning video onto the channel axis of each sample
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        # patch-grid size (F, H, W) of every sample after patch embedding
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        # zero-pad every token sequence to seq_len and batch them
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        # context_lens stays None, so cross-attention receives no key lengths
        context_lens = None
        # zero-pad every text embedding to text_len before embedding it
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if clip_fea is not None:
            # NOTE(review): img_emb only exists when model_type == 'i2v';
            # passing clip_fea to a 't2v' model would fail here - confirm
            # callers never do that.
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            # image tokens go first; the i2v cross-attention slices them off
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # drop padded tokens, then expose grid and in-patch axes
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            # interleave every grid axis with its in-patch axis, channels first
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.

        Linear layers get Xavier-uniform weights and zero biases; the text and
        time embedding linears are then re-drawn from a normal distribution
        (std 0.02) and the output projection weight is zeroed.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
|
humo/models/wan_modules/model_humo.py
ADDED
|
@@ -0,0 +1,803 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from common.distributed import get_device
|
| 5 |
+
from models.audio.audio_proj import AudioProjModel
|
| 6 |
+
|
| 7 |
+
import torch.cuda.amp as amp
|
| 8 |
+
import math
|
| 9 |
+
from humo.models.wan_modules.attention import flash_attention
|
| 10 |
+
from common.distributed.advanced import is_unified_parallel_initialized
|
| 11 |
+
|
| 12 |
+
import types
|
| 13 |
+
|
| 14 |
+
def sinusoidal_embedding_1d(dim, position):
    r"""
    Build a sinusoidal embedding for a 1-D tensor of positions.

    Args:
        dim(int): embedding width, must be even
        position(Tensor): Shape [N]

    Returns:
        Tensor: float64, shape [N, dim]; the first ``dim / 2`` columns are
        cosines and the last ``dim / 2`` columns are sines.
    """
    assert dim % 2 == 0
    n_freqs = dim // 2
    pos64 = position.type(torch.float64)

    # geometric frequency ladder: 10000 ** (-k / n_freqs), k = 0 .. n_freqs - 1
    exponents = -torch.arange(n_freqs).to(pos64).div(n_freqs)
    inv_freq = torch.pow(10000, exponents)

    angles = torch.outer(pos64, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# torch.cuda.amp.autocast is deprecated; torch.amp.autocast('cuda', ...) is
# the documented replacement and behaves the same with enabled=False.
@torch.amp.autocast('cuda', enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    r"""
    Precompute the complex rotary-embedding table.

    Args:
        max_seq_len(int): number of positions to tabulate
        dim(int): rotary width, must be even
        theta(float): base of the geometric frequency ladder

    Returns:
        Tensor: complex tensor of shape [max_seq_len, dim / 2] holding
        ``exp(1j * position * theta ** (-2k / dim))``.
    """
    assert dim % 2 == 0
    # angle[p, k] = p / theta ** (2k / dim)
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float32).div(dim)))
    # unit-magnitude complex numbers with those angles
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# torch.cuda.amp.autocast is deprecated; torch.amp.autocast('cuda', ...) is
# the documented replacement and behaves the same with enabled=False.
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    r"""
    Apply 3-D (frame, height, width) rotary position embeddings.

    Args:
        x(Tensor): Shape [B, L, num_heads, head_dim]; head_dim must be even
            because channels are rotated in (real, imag) pairs
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): complex rope table, shape [max_len, head_dim / 2]

    Returns:
        Tensor: float32 tensor with the shape of ``x``. For each sample only
        the first ``F * H * W`` tokens are rotated; the remaining (padding)
        tokens are passed through unchanged.
    """
    n, c = x.size(2), x.size(3) // 2

    # split the table into its frame / height / width parts
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # view channel pairs of the valid tokens as complex numbers
        x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
            seq_len, n, -1, 2))
        # broadcast each axis' table over the full (f, h, w) grid
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding (complex multiplication == rotation)
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
        # re-attach the untouched padding tokens
        x_i = torch.cat([x_i, x[i, seq_len:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class WanRMSNorm(nn.Module):
    r"""
    Root-mean-square normalisation over the last dimension with a learned
    per-channel scale (initialised to ones).
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # normalise in float32, cast back to the input dtype, then scale
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight

    def _norm(self, x):
        # divide by sqrt(mean(x ** 2) + eps) along the channel axis
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class WanLayerNorm(nn.LayerNorm):
    r"""
    LayerNorm that always computes in float32 and returns the input dtype.
    Unlike ``nn.LayerNorm`` it has no affine parameters by default.
    """

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        r"""
        Args:
            x(Tensor): Shape [B, L, C]
        """
        # upcast for the normalisation, then restore the caller's dtype
        out_fp32 = super().forward(x.float())
        return out_fp32.type_as(x)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class WanSelfAttention(nn.Module):
    r"""
    Multi-head self-attention with rotary position embeddings and optional
    RMS normalisation of queries and keys.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()

        # hyper-parameters
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)

        # query/key normalisation (identity when qk_norm is off)
        if qk_norm:
            self.norm_q = WanRMSNorm(dim, eps=eps)
        else:
            self.norm_q = nn.Identity()
        if qk_norm:
            self.norm_k = WanRMSNorm(dim, eps=eps)
        else:
            self.norm_k = nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120]); it is
                viewed as [B, L, num_heads, C / num_heads] after projection
            seq_lens(Tensor): Shape [B], e.g. tensor([9360])
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains
                (F, H, W), e.g. tensor([[ 6, 30, 52]])
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

        Returns:
            Tensor: Shape [B, L, C]
        """
        batch, tokens = x.shape[:2]
        head_shape = (batch, tokens, self.num_heads, self.head_dim)

        # project, normalise, split into heads
        query = self.norm_q(self.q(x)).view(*head_shape)
        key = self.norm_k(self.k(x)).view(*head_shape)
        value = self.v(x).view(*head_shape)

        # rotary embedding is applied to query and key only
        attended = flash_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)

        # merge heads and project out
        return self.o(attended.flatten(2))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class WanSelfAttentionSepKVDim(nn.Module):
    r"""
    Attention layer whose key/value projections read ``kv_dim`` input
    channels while the query and output projections use ``dim``.

    NOTE(review): ``forward`` below feeds the same ``x`` to the q, k and v
    projections, so as written it only fits inputs where ``kv_dim == dim``;
    WanT2VCrossAttentionGather overrides ``forward`` - confirm whether this
    base implementation is ever called directly.
    """

    def __init__(self,
                 kv_dim,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()

        # hyper-parameters
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # projections: keys/values consume kv_dim channels
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(kv_dim, dim)
        self.v = nn.Linear(kv_dim, dim)
        self.o = nn.Linear(dim, dim)

        # query/key normalisation (identity when qk_norm is off)
        if qk_norm:
            self.norm_q = WanRMSNorm(dim, eps=eps)
        else:
            self.norm_q = nn.Identity()
        if qk_norm:
            self.norm_k = WanRMSNorm(dim, eps=eps)
        else:
            self.norm_k = nn.Identity()

    def forward(self, x, seq_lens, grid_sizes, freqs):
        r"""
        Args:
            x(Tensor): Shape [B, L, C], e.g. torch.Size([1, 9360, 5120]); it is
                viewed as [B, L, num_heads, C / num_heads] after projection
            seq_lens(Tensor): Shape [B], e.g. tensor([9360])
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains
                (F, H, W), e.g. tensor([[ 6, 30, 52]])
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

        Returns:
            Tensor: Shape [B, L, C]
        """
        batch, tokens = x.shape[:2]
        head_shape = (batch, tokens, self.num_heads, self.head_dim)

        # project, normalise, split into heads
        query = self.norm_q(self.q(x)).view(*head_shape)
        key = self.norm_k(self.k(x)).view(*head_shape)
        value = self.v(x).view(*head_shape)

        # rotary embedding is applied to query and key only
        attended = flash_attention(
            q=rope_apply(query, grid_sizes, freqs),
            k=rope_apply(key, grid_sizes, freqs),
            v=value,
            k_lens=seq_lens,
            window_size=self.window_size)

        # merge heads and project out
        return self.o(attended.flatten(2))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class WanT2VCrossAttention(WanSelfAttention):
    r"""
    Cross-attention that reuses the q/k/v/o projections of WanSelfAttention:
    queries come from ``x``, keys and values from ``context``.
    """

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]

        Returns:
            Tensor: Shape [B, L1, C]
        """
        batch = x.size(0)
        heads, width = self.num_heads, self.head_dim

        # queries from the input tokens, keys/values from the context
        query = self.norm_q(self.q(x)).view(batch, -1, heads, width)
        key = self.norm_k(self.k(context)).view(batch, -1, heads, width)
        value = self.v(context).view(batch, -1, heads, width)

        attended = flash_attention(query, key, value, k_lens=context_lens)

        # merge heads and apply the output projection
        return self.o(attended.flatten(2))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class WanT2VCrossAttentionGather(WanSelfAttentionSepKVDim):
    r"""
    Frame-wise audio cross-attention.

    Video queries and audio keys/values are regrouped so that the tokens of
    one video frame attend only to the audio tokens of the same frame.
    """

    # Number of audio context tokens that belong to one video frame.
    AUDIO_TOKENS_PER_FRAME = 16

    def forward(self, x, context, context_lens, grid_sizes, freqs, audio_seq_len):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C], video tokens
            context(Tensor): Shape [B, L2, kv_dim], audio tokens with
                L2 = frames * AUDIO_TOKENS_PER_FRAME
            context_lens(Tensor): unused here
            grid_sizes(Tensor): unused here
            freqs(Tensor): unused here
            audio_seq_len: unused here

        Returns:
            Tensor: Shape [B, L1, C]

        Note:
            All sizes are derived from tensor shapes (no ``int(tensor)``
            casts), and no explicit shape checks are performed: ``context``
            must hold at least one full frame of audio tokens and ``L1`` must
            be a multiple of the resulting frame count, otherwise the
            divisions/reshapes below fail.
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim
        per_frame = self.AUDIO_TOKENS_PER_FRAME

        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        # derive sizes from shapes (SymInts) rather than from tensor values
        Lq = q.shape[1]  # total video tokens per sample
        frames = context.shape[1] // per_frame
        hlen_wlen = Lq // frames  # tokens per frame = H*W

        # fold the frame axis into the batch axis so that attention is
        # computed independently for every (sample, frame) pair
        q = q.reshape(-1, hlen_wlen, n, d)
        k = k.reshape(-1, per_frame, n, d)
        v = v.reshape(-1, per_frame, n, d)

        # no key-length masking for the audio tokens
        x = flash_attention(q, k, v, k_lens=None)

        # restore [B, L1, heads, head_dim], merge heads, project out
        x = x.view(b, -1, n, d).flatten(2)
        x = self.o(x)
        return x
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class AudioCrossAttentionWrapper(nn.Module):
    r"""
    Residual audio cross-attention sub-layer.

    The input tokens are layer-normalized, passed through a
    ``WanT2VCrossAttentionGather`` together with the audio features, and the
    attention output is added back onto the un-normalized input.
    """

    def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6,):
        super().__init__()

        # Note the (kv_dim, dim) argument order expected by the gather
        # attention; (-1, -1) is the window size used for global attention.
        gather_attn = WanT2VCrossAttentionGather(
            kv_dim, dim, num_heads, (-1, -1), qk_norm, eps)
        self.audio_cross_attn = gather_attn
        self.norm1_audio = WanLayerNorm(dim, eps, elementwise_affine=True)

    def forward(self, x, audio, seq_lens, grid_sizes, freqs, audio_seq_len):
        r"""
        Args:
            x(Tensor): tokens to be updated
            audio(Tensor): audio features used as attention context
            seq_lens, grid_sizes, freqs, audio_seq_len: forwarded unchanged
                to the wrapped attention module

        Returns:
            Tensor: ``x`` plus the audio cross-attention output
        """
        normed = self.norm1_audio(x)
        attended = self.audio_cross_attn(
            normed, audio, seq_lens, grid_sizes, freqs, audio_seq_len)
        return x + attended
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class WanI2VCrossAttention(WanSelfAttention):
    r"""
    Cross-attention between token queries and a context sequence, reusing the
    projections and norms created by ``WanSelfAttention``.
    """

    def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
                 eps=1e-6):
        super().__init__(dim, num_heads, window_size, qk_norm, eps)

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]

        Returns:
            Tensor: attention output projected by ``self.o``
        """
        batch = x.size(0)
        heads, head_dim = self.num_heads, self.head_dim

        # project and split into heads: [B, L, heads, head_dim]
        query = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)
        key = self.norm_k(self.k(context)).view(batch, -1, heads, head_dim)
        value = self.v(context).view(batch, -1, heads, head_dim)

        attended = flash_attention(query, key, value, k_lens=context_lens)

        # merge heads back into the channel dimension and project out
        return self.o(attended.flatten(2))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# Registry mapping a cross-attention type key to the class implementing it.
# WanAttentionBlock looks up its cross-attention layer here via `cross_attn_type`.
WAN_CROSSATTENTION_CLASSES = {
    't2v_cross_attn': WanT2VCrossAttention,
    'i2v_cross_attn': WanI2VCrossAttention,
}
|
| 349 |
+
|
| 350 |
+
class WanAttentionBlock(nn.Module):
    r"""
    One transformer block of the Wan backbone.

    Applies, in order: modulated self-attention, cross-attention to the
    context, an optional audio cross-attention sub-layer, and a modulated
    feed-forward network. Each stage is added residually.
    """

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 use_audio=True):
        super().__init__()

        # hyper-parameters kept for introspection
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # Sub-modules are created in a fixed order so that parameter
        # registration (and therefore state-dict layout) stays stable.
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(
            dim, num_heads, window_size, qk_norm, eps)

        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()

        cross_attn_cls = WAN_CROSSATTENTION_CLASSES[cross_attn_type]
        self.cross_attn = cross_attn_cls(
            dim, num_heads, (-1, -1), qk_norm, eps)

        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim))

        # learned offset added to the six time-modulation vectors
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        self.use_audio = use_audio
        if use_audio:
            # 1536 is the channel width of the audio features fed to this block
            self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(
                dim, 1536, num_heads, qk_norm, eps)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
        audio=None,
        audio_seq_len=None,
        ref_num_list=None,
    ):
        r"""
        Args:
            x(Tensor): Shape [B, L, C], e.g. [1, 9360, 5120]
            e(Tensor): float32 time modulation, Shape [B, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
            context(Tensor): context tokens for cross-attention, e.g. [1, 512, 5120]
            context_lens(Tensor or None): valid context lengths
            audio(Tensor or None): audio features; used only when ``use_audio``
            audio_seq_len: forwarded to the audio cross-attention
            ref_num_list: together with seq_lens, tells how far from the end
                the reference image sits; not read by this method

        Returns:
            Tensor: updated tokens, same layout as ``x``
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            mods = (self.modulation + e).chunk(6, dim=1)
        assert mods[0].dtype == torch.float32
        (shift_attn, scale_attn, gate_attn,
         shift_ffn, scale_ffn, gate_ffn) = mods

        # self-attention on the modulated, normalized tokens
        attn_in = self.norm1(x).float() * (1 + scale_attn) + shift_attn
        attn_out = self.self_attn(attn_in, seq_lens, grid_sizes, freqs)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + attn_out * gate_attn

        # cross-attention to the context
        x = x + self.cross_attn(self.norm3(x), context, context_lens)

        # optional audio cross-attention (residual handled by the wrapper)
        if self.use_audio:
            x = self.audio_cross_attn_wrapper(
                x, audio, seq_lens, grid_sizes, freqs, audio_seq_len)

        # feed-forward network on the modulated, normalized tokens
        ffn_out = self.ffn(self.norm2(x).float() * (1 + scale_ffn) + shift_ffn)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            x = x + ffn_out * gate_ffn

        return x
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
class Head(nn.Module):
    r"""
    Output head: modulated layer norm followed by a linear projection to
    ``prod(patch_size) * out_dim`` channels per token.
    """

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # each token is projected to one full patch worth of output channels
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # learned offset added to the (shift, scale) time modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C], must be float32

        Returns:
            Tensor: projected tokens, Shape [B, L1, prod(patch_size) * out_dim]
        """
        assert e.dtype == torch.float32
        with torch.amp.autocast('cuda', dtype=torch.float32):
            shift, scale = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = self.head(self.norm(x) * (1 + scale) + shift)
        return x
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class MLPProj(torch.nn.Module):
    r"""
    Two-layer MLP projector: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm,
    mapping features of width ``in_dim`` to width ``out_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        # Layer positions inside the Sequential are part of the state-dict
        # keys (proj.0, proj.1, proj.3, proj.4), so the order is fixed.
        stages = [
            torch.nn.LayerNorm(in_dim),
            torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(),
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*stages)

    def forward(self, image_embeds):
        r"""
        Args:
            image_embeds(Tensor): Shape [..., in_dim]

        Returns:
            Tensor: projected tokens, Shape [..., out_dim]
        """
        return self.proj(image_embeds)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class WanModel(nn.Module):
    r"""
    Wan diffusion backbone supporting both text-to-video and image-to-video.

    This variant can additionally inject audio features into every
    transformer block (``insert_audio``) and can skip the block stack on a
    step when a ``tea_cache`` object passed to ``forward`` decides to reuse
    cached results.
    """

    # constructor arguments listed here are skipped by the config machinery
    # (presumably a diffusers-style config mixin — TODO confirm)
    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    _no_split_modules = ['WanAttentionBlock']

    gradient_checkpointing = False

    def __init__(self,
                 model_type='t2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=13824,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=40,
                 num_layers=40,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 audio_token_num=16,
                 insert_audio=True):
        r"""
        Initialize the diffusion model backbone.

        Args:
            model_type (`str`, *optional*, defaults to 't2v'):
                Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
            patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
                3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
            text_len (`int`, *optional*, defaults to 512):
                Fixed length for text embeddings
            in_dim (`int`, *optional*, defaults to 16):
                Input video channels (C_in)
            dim (`int`, *optional*, defaults to 2048):
                Hidden dimension of the transformer
            ffn_dim (`int`, *optional*, defaults to 13824):
                Intermediate dimension in feed-forward network
            freq_dim (`int`, *optional*, defaults to 256):
                Dimension for sinusoidal time embeddings
            text_dim (`int`, *optional*, defaults to 4096):
                Input dimension for text embeddings
            out_dim (`int`, *optional*, defaults to 16):
                Output video channels (C_out)
            num_heads (`int`, *optional*, defaults to 40):
                Number of attention heads
            num_layers (`int`, *optional*, defaults to 40):
                Number of transformer blocks
            window_size (`tuple`, *optional*, defaults to (-1, -1)):
                Window size for local attention (-1 indicates global attention)
            qk_norm (`bool`, *optional*, defaults to True):
                Enable query/key normalization
            cross_attn_norm (`bool`, *optional*, defaults to True):
                Enable cross-attention normalization
            eps (`float`, *optional*, defaults to 1e-6):
                Epsilon value for normalization layers
            audio_token_num (`int`, *optional*, defaults to 16):
                Passed to the audio projector as ``context_tokens``
            insert_audio (`bool`, *optional*, defaults to True):
                Build the audio projector and enable audio cross-attention
                in every block
        """

        super().__init__()

        assert model_type in ['t2v', 'i2v']
        self.model_type = model_type

        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        # produces the six per-block modulation vectors (dim * 6 channels)
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
        self.insert_audio = insert_audio
        self.blocks = nn.ModuleList([
            WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm,
                              eps, use_audio=self.insert_audio)
            for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        if self.insert_audio:
            # output_dim=1536 matches the audio width hard-coded in WanAttentionBlock
            self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280,
                                             intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num)

        # RoPE freqs: register as a buffer so it moves with .to() / DDP and is tracked by compile
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads

        # three frequency tables concatenated along dim 1; their widths sum to d
        _freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)
        # persistent=False: the table is recomputed here, not stored in checkpoints
        self.register_buffer("freqs", _freqs, persistent=False)

        # initialize weights
        self.init_weights()

        # initialize unified parallel
        if is_unified_parallel_initialized():
            print(f"Initializing WanModel with unified parallel initialized")
            from humo.models.distributed.dit_ulysses_sequence_parallel import ulysses_attn_forward, ulysses_dit_forward, ulysses_audio_cross_attn_forward
            # Rebind forward on each instance so the sequence-parallel
            # implementations replace the ones defined in this file.
            for block in self.blocks:
                block.self_attn.forward = types.MethodType(ulysses_attn_forward, block.self_attn)
                if block.use_audio:
                    block.audio_cross_attn_wrapper.audio_cross_attn.forward = types.MethodType(ulysses_audio_cross_attn_forward, block.audio_cross_attn_wrapper.audio_cross_attn)
            self.forward = types.MethodType(ulysses_dit_forward, self)

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        audio=None,
        y=None,
        tea_cache=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding; every
                sample is zero-padded to this many tokens
            audio (List[Tensor], *optional*):
                Per-sample audio features fed to ``audio_proj``. Iterated
                unconditionally when the model was built with
                ``insert_audio=True``, so it must be provided in that case
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x
            tea_cache (*optional*):
                Cache object exposing ``check``, ``update`` and ``store``;
                when ``check`` returns a truthy value the transformer
                blocks are skipped and ``update`` supplies the tokens

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v':
            assert y is not None

        # params
        freqs = self.freqs

        # channel-wise concatenation of the conditioning video onto the input
        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        # (F, H, W) patch-grid size per sample, taken after patch embedding
        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])

        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len

        # pad to uniform length and batch
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
            for u in x
        ])  # shape: [B, seq_len, C]

        # time embeddings
        with torch.amp.autocast('cuda', dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float()
            ).float()
            e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context: zero-pad every text embedding to text_len, then embed
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ])
        )

        # audio (unchanged; not cached)
        if self.insert_audio:
            audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
            # longest (dim 2) extent in the batch times the dim-3 extent of
            # the first sample = padded number of audio tokens
            audio_seq_len = max(au.shape[2] for au in audio) * audio[0].shape[3]

            audio = [au.flatten(2).transpose(1, 2) for au in audio]  # [1, dim2 * dim3, C]
            audio = torch.cat([
                torch.cat([au, au.new_zeros(1, int(audio_seq_len) - au.size(1), au.size(2))], dim=1)
                for au in audio
            ])
        else:
            audio = None
            audio_seq_len = None

        # ---- tea_cache integration ----
        if tea_cache is not None:
            # Use the pre-block tokens 'x' and time-mod 'e0' to decide whether to reuse cache
            tea_cache_update = tea_cache.check(self, x, e0)
        else:
            tea_cache_update = False

        ori_x_len = x.shape[1]  # remember original token length before potential cache extension

        if tea_cache_update:
            # Skip the block stack; the cache produces the tokens for this step
            x = tea_cache.update(x)
        else:
            # arguments for blocks
            kwargs = dict(
                e=e0,
                seq_lens=seq_lens,
                grid_sizes=grid_sizes,
                freqs=freqs,
                context=context,
                context_lens=context_lens,
                audio=audio,
                audio_seq_len=audio_seq_len
            )

            # transformer blocks
            for block in self.blocks:
                x = block(x, **kwargs)

            # NOTE(review): nesting of this store under the `else` branch was
            # reconstructed from context — confirm against the upstream file.
            if tea_cache is not None:
                x_cache = x[:, :ori_x_len]
                tea_cache.store(x_cache)

        # head (takes the un-projected time embedding e, not e0)
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]


    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # drop padding tokens, then expose grid and in-patch axes
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            # interleave each grid axis with its in-patch axis: (f,p), (h,q), (w,r)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.

        All linear layers get Xavier-uniform weights and zero biases; the
        text and time embedding MLPs are then re-drawn from N(0, 0.02^2), and
        the output projection weight is zeroed.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
|
humo/models/wan_modules/t5.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from transformers.models.t5.modeling_t5
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from .tokenizers import HuggingfaceTokenizer
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
'T5Model',
|
| 14 |
+
'T5Encoder',
|
| 15 |
+
'T5Decoder',
|
| 16 |
+
'T5EncoderModel',
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def fp16_clamp(x):
|
| 21 |
+
if x.dtype == torch.float16 and torch.isinf(x).any():
|
| 22 |
+
clamp = torch.finfo(x.dtype).max - 1000
|
| 23 |
+
x = torch.clamp(x, min=-clamp, max=clamp)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def init_weights(m):
    """Initialize one T5 sub-module in place, chosen by its type.

    Modules of any other type are left untouched.
    """
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
        return

    if isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
        return

    if isinstance(m, T5FeedForward):
        # input-side projections scale with dim, the output one with dim_ffn
        in_std = m.dim**-0.5
        nn.init.normal_(m.gate[0].weight, std=in_std)
        nn.init.normal_(m.fc1.weight, std=in_std)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
        return

    if isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
        return

    if isinstance(m, T5RelativeEmbedding):
        bucket_std = (2 * m.num_buckets * m.num_heads)**-0.5
        nn.init.normal_(m.embedding.weight, std=bucket_std)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GELU(nn.Module):
    """GELU activation computed with the tanh approximation."""

    def forward(self, x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class T5LayerNorm(nn.Module):
    """RMS-style layer norm: scales by the inverse root-mean-square of the
    last dimension and a learned weight, with no mean subtraction or bias."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # statistics are computed in float32 regardless of the input dtype
        mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.eps)

        # cast back down only when the weight itself is half precision
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            x = x.type_as(self.weight)
        return self.weight * x
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class T5Attention(nn.Module):
    """Multi-head attention without score scaling, with additive position
    bias and optional masking. Acts as self-attention when no context is
    given."""

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # bias-free projections
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:          [B, L1, C].
        context:    [B, L2, C] or None.
        mask:       [B, L2] or [B, L1, L2] or None.
        pos_bias:   additive bias broadcastable to [B, N, L1, L2], or None.
        """
        if context is None:
            context = x
        batch, heads, head_dim = x.size(0), self.num_heads, self.head_dim

        # project and split heads: [B, L, N, head_dim]
        query = self.q(x).view(batch, -1, heads, head_dim)
        key = self.k(context).view(batch, -1, heads, head_dim)
        value = self.v(context).view(batch, -1, heads, head_dim)

        # accumulate position bias and mask into one additive term
        attn_bias = x.new_zeros(batch, heads, query.size(1), key.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            if mask.ndim == 2:
                mask = mask.view(batch, 1, 1, -1)
            else:
                mask = mask.unsqueeze(1)
            # masked positions get the most negative representable value
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling); softmax in float32
        scores = torch.einsum('binc,bjnc->bnij', query, key) + attn_bias
        probs = F.softmax(scores.float(), dim=-1).type_as(scores)
        mixed = torch.einsum('bnij,bjnc->binc', probs, value)

        # merge heads, project out, apply dropout
        merged = mixed.reshape(batch, -1, heads * head_dim)
        return self.dropout(self.o(merged))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class T5FeedForward(nn.Module):
    """Gated feed-forward network: ``fc2(dropout(fc1(x) * gate(x)))`` followed
    by a second dropout."""

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # bias-free projections; the gate branch ends in a GELU
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.fc1(x) * self.gate(x)
        hidden = self.dropout(hidden)
        projected = self.fc2(hidden)
        return self.dropout(projected)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class T5SelfAttention(nn.Module):
    """Pre-norm encoder block: self-attention then feed-forward, each added
    residually and passed through ``fp16_clamp``."""

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)

        # a block owns its own position embedding only when it is not shared
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        # use the caller-supplied bias when shared, else compute our own
        if self.shared_pos:
            bias = pos_bias
        else:
            bias = self.pos_embedding(x.size(1), x.size(1))

        attn_out = self.attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + attn_out)

        ffn_out = self.ffn(self.norm2(x))
        return fp16_clamp(x + ffn_out)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class T5CrossAttention(nn.Module):
    """Pre-norm decoder block: self-attention, cross-attention over encoder
    states, then feed-forward, each added residually and passed through
    ``fp16_clamp``."""

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)

        # a block owns its own (unidirectional) position embedding only when
        # it is not shared
        if shared_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=False)

    def forward(self,
                x,
                mask=None,
                encoder_states=None,
                encoder_mask=None,
                pos_bias=None):
        # use the caller-supplied bias when shared, else compute our own
        if self.shared_pos:
            bias = pos_bias
        else:
            bias = self.pos_embedding(x.size(1), x.size(1))

        self_out = self.self_attn(self.norm1(x), mask=mask, pos_bias=bias)
        x = fp16_clamp(x + self_out)

        # cross-attention receives no position bias
        cross_out = self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask)
        x = fp16_clamp(x + cross_out)

        ffn_out = self.ffn(self.norm3(x))
        return fp16_clamp(x + ffn_out)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class T5RelativeEmbedding(nn.Module):
    """Learned relative-position bias: relative offsets are mapped to
    buckets and each bucket has one learned value per attention head."""

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        """Return the position bias of shape [1, N, lq, lk]."""
        dev = self.embedding.weight.device
        key_pos = torch.arange(lk, device=dev).unsqueeze(0)
        query_pos = torch.arange(lq, device=dev).unsqueeze(1)

        # offset of every key relative to every query: [lq, lk]
        buckets = self._relative_position_bucket(key_pos - query_pos)
        bias = self.embedding(buckets)  # [lq, lk, N]
        return bias.permute(2, 0, 1).unsqueeze(0).contiguous()

    def _relative_position_bucket(self, rel_pos):
        """Map integer relative offsets to bucket indices."""
        if self.bidirectional:
            # half the buckets for each direction; positive offsets use the
            # upper half
            num_buckets = self.num_buckets // 2
            base = (rel_pos > 0).long() * num_buckets
            distance = torch.abs(rel_pos)
        else:
            # only non-positive offsets count; positive ones collapse to 0
            num_buckets = self.num_buckets
            base = 0
            distance = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # small distances get one bucket each, larger ones are binned
        # logarithmically up to max_dist and capped at the last bucket
        max_exact = num_buckets // 2
        log_ratio = torch.log(distance.float() / max_exact) / \
            math.log(self.max_dist / max_exact)
        large = max_exact + (log_ratio * (num_buckets - max_exact)).long()
        large = torch.min(large, torch.full_like(large, num_buckets - 1))

        return base + torch.where(distance < max_exact, distance, large)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class T5Encoder(nn.Module):
    """Stack of T5 self-attention blocks over embedded token ids.

    ``vocab`` is either an existing ``nn.Embedding`` (used as-is) or a
    vocabulary size from which a new embedding table is created. When
    ``shared_pos`` is true a single bidirectional relative-position bias is
    computed once per forward pass and handed to every block.
    """

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=True)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        layers = []
        for _ in range(num_layers):
            layers.append(
                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads,
                                num_buckets, shared_pos, dropout))
        self.blocks = nn.ModuleList(layers)
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        """Embed ``ids``, run every block, then apply the final norm and dropout."""
        hidden = self.dropout(self.token_embedding(ids))
        if self.shared_pos:
            length = hidden.size(1)
            shared_bias = self.pos_embedding(length, length)
        else:
            shared_bias = None
        for blk in self.blocks:
            hidden = blk(hidden, mask, pos_bias=shared_bias)
        return self.dropout(self.norm(hidden))
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class T5Decoder(nn.Module):
    """Stack of T5 cross-attention blocks with a lower-triangular self-attention mask.

    ``vocab`` is either an existing ``nn.Embedding`` (used as-is) or a
    vocabulary size from which a new embedding table is created. When
    ``shared_pos`` is true a single unidirectional relative-position bias is
    computed once per forward pass and handed to every block.
    """

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super().__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        if isinstance(vocab, nn.Embedding):
            self.token_embedding = vocab
        else:
            self.token_embedding = nn.Embedding(vocab, dim)
        if shared_pos:
            self.pos_embedding = T5RelativeEmbedding(
                num_buckets, num_heads, bidirectional=False)
        else:
            self.pos_embedding = None
        self.dropout = nn.Dropout(dropout)
        layers = []
        for _ in range(num_layers):
            layers.append(
                T5CrossAttention(dim, dim_attn, dim_ffn, num_heads,
                                 num_buckets, shared_pos, dropout))
        self.blocks = nn.ModuleList(layers)
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        """Decode ``ids`` while attending to ``encoder_states``.

        With no ``mask`` a [1, s, s] lower-triangular mask is built; a 2-D
        ``mask`` is expanded to [b, s, s] and lower-triangularised. Any other
        mask is passed through unchanged.
        """
        _, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.ones(1, s, s).to(ids.device).tril()
        elif mask.ndim == 2:
            mask = mask.unsqueeze(1).expand(-1, s, -1).tril()

        # layers
        hidden = self.dropout(self.token_embedding(ids))
        if self.shared_pos:
            length = hidden.size(1)
            shared_bias = self.pos_embedding(length, length)
        else:
            shared_bias = None
        for blk in self.blocks:
            hidden = blk(hidden, mask, encoder_states, encoder_mask,
                         pos_bias=shared_bias)
        return self.dropout(self.norm(hidden))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class T5Model(nn.Module):
    """Encoder-decoder T5 with one token embedding shared by both halves
    and a bias-free linear head projecting to vocabulary size."""

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        # arguments common to the encoder and decoder, except the layer count
        front = (self.token_embedding, dim, dim_attn, dim_ffn, num_heads)
        back = (num_buckets, shared_pos, dropout)
        self.encoder = T5Encoder(*front, encoder_layers, *back)
        self.decoder = T5Decoder(*front, decoder_layers, *back)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        """Encode, decode against the encoder output, and project with the head."""
        memory = self.encoder(encoder_ids, encoder_mask)
        decoded = self.decoder(decoder_ids, decoder_mask, memory, encoder_mask)
        return self.head(decoded)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs=None,
        dtype=torch.float32,
        device='cpu',
        **kwargs):
    """Build a T5 model (full, encoder-only or decoder-only).

    Args:
        name: Model name; the tokenizer is loaded from ``google/{name}``.
        encoder_only: Build a ``T5Encoder`` instead of the full model.
        decoder_only: Build a ``T5Decoder`` instead of the full model.
        return_tokenizer: Also create and return a ``HuggingfaceTokenizer``.
        tokenizer_kwargs: Extra keyword arguments for the tokenizer;
            ``None`` means no extra arguments.
        dtype: Parameter dtype the model is cast to.
        device: Device the model is created on and moved to.
        **kwargs: Model hyper-parameters. For the encoder-only and
            decoder-only variants ``vocab_size``, ``encoder_layers`` and
            ``decoder_layers`` must all be present (``KeyError`` otherwise).

    Returns:
        The model, or ``(model, tokenizer)`` when ``return_tokenizer`` is set.
    """
    # sanity check
    assert not (encoder_only and decoder_only)

    # params: single-stack classes take `vocab`/`num_layers` instead of the
    # full-model names, and the other stack's layer count is discarded
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    if not return_tokenizer:
        return model

    # init tokenizer
    from .tokenizers import HuggingfaceTokenizer
    # a None default replaces the shared mutable `{}` default
    tokenizer = HuggingfaceTokenizer(f'google/{name}',
                                     **(tokenizer_kwargs or {}))
    return model, tokenizer
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def umt5_xxl(**kwargs):
    """Build the ``umt5-xxl`` configuration via ``_t5``.

    Keyword arguments override the defaults below and are otherwise
    forwarded to ``_t5`` unchanged.
    """
    cfg = {
        'vocab_size': 256384,
        'dim': 4096,
        'dim_attn': 4096,
        'dim_ffn': 10240,
        'num_heads': 64,
        'encoder_layers': 24,
        'decoder_layers': 24,
        'num_buckets': 32,
        'shared_pos': False,
        'dropout': 0.1,
    }
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class T5EncoderModel(nn.Module):
    """Text encoder wrapper: a ``T5Encoder`` plus its tokenizer.

    Args:
        text_len: Sequence length passed to the tokenizer as ``seq_len``.
        dtype: Parameter dtype of the encoder.
        device: Device for the encoder. ``None`` (default) resolves at
            construction time to the current CUDA device, or to CPU when
            CUDA is unavailable.
        checkpoint_path: Optional state-dict file loaded into the encoder.
        tokenizer_path: Name or path handed to ``HuggingfaceTokenizer``.
        shard_fn: Optional callable ``shard_fn(model, sync_module_states=False)``
            that wraps the encoder; when absent the encoder is moved to
            ``device`` instead.
    """

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=None,
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        super().__init__()
        # Resolve the default lazily: a `torch.cuda.current_device()` default
        # argument is evaluated when the class is defined, which touches CUDA
        # on import and fails on machines without it.
        if device is None:
            device = (torch.cuda.current_device()
                      if torch.cuda.is_available() else 'cpu')
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        with torch.device(device):
            self.model = T5Encoder(
                vocab=256384,
                dim=4096,
                dim_attn=4096,
                dim_ffn=10240,
                num_heads=64,
                num_layers=24,
                num_buckets=32,
                shared_pos=False,
                dropout=0.1,
            )
        # set device; the encoder is inference-only (eval mode, no gradients)
        self.model = self.model.to(
            dtype=dtype, device=device).eval().requires_grad_(False)

        if checkpoint_path is not None:
            logging.info(f'loading {checkpoint_path}')
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            self.model.load_state_dict(
                torch.load(checkpoint_path, map_location='cpu'))

        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    @torch.no_grad()
    def __call__(self, texts, device):
        """Encode ``texts`` and return one tensor per text.

        Each returned tensor is the encoder output for that text, truncated
        to the number of positive entries in its attention mask.
        """
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
|
humo/models/wan_modules/tokenizers.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import html
|
| 3 |
+
import string
|
| 4 |
+
|
| 5 |
+
import ftfy
|
| 6 |
+
import regex as re
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
|
| 9 |
+
__all__ = ['HuggingfaceTokenizer']
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, and strip the ends."""
    fixed = ftfy.fix_text(text)
    # two passes so doubly-escaped entities are decoded as well
    for _ in range(2):
        fixed = html.unescape(fixed)
    return fixed.strip()
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def whitespace_clean(text):
    """Collapse each whitespace run to a single space and strip the ends."""
    return re.sub(r'\s+', ' ', text).strip()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lower-case ``text`` with punctuation removed and whitespace collapsed.

    Underscores become spaces first. If ``keep_punctuation_exact_string`` is
    given, occurrences of that exact string are kept while punctuation is
    removed from the text between them.
    """
    drop_punct = str.maketrans('', '', string.punctuation)
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        pieces = text.split(keep_punctuation_exact_string)
        stripped = [piece.translate(drop_punct) for piece in pieces]
        text = keep_punctuation_exact_string.join(stripped)
    else:
        text = text.translate(drop_punct)
    collapsed = re.sub(r'\s+', ' ', text.lower())
    return collapsed.strip()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class HuggingfaceTokenizer:
    """Wrapper around ``AutoTokenizer`` with optional text cleaning and
    fixed-length padding/truncation."""

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a string or a list of strings.

        ``return_mask=True`` additionally returns the attention mask. All
        other keyword arguments are forwarded to the underlying tokenizer
        and override the defaults assembled here.
        """
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        options = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            options['padding'] = 'max_length'
            options['truncation'] = True
            options['max_length'] = self.seq_len
        options.update(**kwargs)

        # tokenization
        texts = [sequence] if isinstance(sequence, str) else sequence
        if self.clean:
            texts = [self._clean(item) for item in texts]
        encoded = self.tokenizer(texts, **options)

        # output
        if return_mask:
            return encoded.input_ids, encoded.attention_mask
        return encoded.input_ids

    def _clean(self, text):
        """Apply the cleaning mode chosen at construction; other modes return ``text`` unchanged."""
        mode = self.clean
        if mode == 'canonicalize':
            return canonicalize(basic_clean(text))
        if mode == 'whitespace':
            return whitespace_clean(basic_clean(text))
        if mode == 'lower':
            return whitespace_clean(basic_clean(text)).lower()
        return text
|
humo/models/wan_modules/vae.py
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.cuda.amp as amp
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
'WanVAE',
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
CACHE_T = 2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.

    The padding given at construction is applied manually in ``forward``:
    the last two dimensions are padded on both sides, while dimension 2 is
    padded only at the front, with twice the configured amount.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        pad_t, pad_h, pad_w = self.padding
        # F.pad order: last dim first, (left, right) per dim
        self._padding = (pad_w, pad_w, pad_h, pad_h, 2 * pad_t, 0)
        # the underlying convolution itself no longer pads
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        """Convolve ``x``; ``cache_x`` (if given) is prepended along dim 2
        and replaces that many frames of front padding."""
        pad = list(self._padding)
        has_front_pad = self._padding[4] > 0
        if cache_x is not None and has_front_pad:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            pad[4] -= cache_x.shape[2]
        padded = F.pad(x, pad)
        return super().forward(padded)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class RMS_norm(nn.Module):
    """L2-normalise along the channel axis, then scale by ``sqrt(dim)`` and
    a learned ``gamma`` (plus an optional learned bias).

    With ``channel_first`` the channel axis is dim 1 and the parameters get
    two (``images=True``) or three (``images=False``) trailing singleton
    dims; otherwise the channel axis is the last one.
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        if channel_first:
            trailing = (1, 1) if images else (1, 1, 1)
            shape = (dim, *trailing)
        else:
            shape = (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        axis = 1 if self.channel_first else -1
        unit = F.normalize(x, dim=axis)
        return unit * self.scale * self.gamma + self.bias
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Upsample(nn.Upsample):
    """``nn.Upsample`` that interpolates in float32 and casts back."""

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Resample(nn.Module):
    """Spatial (2d) or spatio-temporal (3d) up/down-sampling layer.

    ``mode`` selects the behaviour:
      * ``'upsample2d'`` / ``'upsample3d'``: 2x nearest-exact spatial
        upsampling followed by a conv that halves the channels; the 3d
        variant additionally owns a temporal ``time_conv``.
      * ``'downsample2d'`` / ``'downsample3d'``: zero-pad right/bottom and
        apply a stride-2 conv; the 3d variant additionally owns a stride-2
        temporal ``time_conv``.
      * ``'none'``: identity.

    The temporal convolutions are only applied in ``forward`` when a
    ``feat_cache`` is supplied.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            # doubles the channels; forward() unfolds them into 2x frames
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Resample a 5-D tensor ``x`` of size (b, c, t, h, w).

        Args:
            x: Input tensor.
            feat_cache: Optional list of per-layer cache slots, read and
                updated in place at position ``feat_idx[0]``.
            feat_idx: One-element list holding the current cache slot index;
                incremented in place when a slot is consumed. ``None``
                (default) starts a fresh counter at 0.
        """
        # A fresh list per call: the old `feat_idx=[0]` default was mutated
        # below and therefore leaked state between calls.
        if feat_idx is None:
            feat_idx = [0]

        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            prev = feat_cache[idx]
            if prev is None:
                # first chunk: mark the slot and skip the temporal conv
                feat_cache[idx] = 'Rep'
                feat_idx[0] += 1
            else:
                # isinstance guard avoids comparing a tensor with a string
                after_rep = isinstance(prev, str) and prev == 'Rep'

                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2:
                    if after_rep:
                        # no earlier frames available: pad with zeros
                        filler = torch.zeros_like(cache_x)
                    else:
                        # cache last frame of last two chunk
                        filler = prev[:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device)
                    cache_x = torch.cat([filler, cache_x], dim=2)

                if after_rep:
                    x = self.time_conv(x)
                else:
                    x = self.time_conv(x, prev)
                feat_cache[idx] = cache_x
                feat_idx[0] += 1

                # interleave the two channel halves as consecutive frames
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0], x[:, 1]), 3)
                x = x.reshape(b, c, t * 2, h, w)

        # spatial resampling is applied frame by frame
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d' and feat_cache is not None:
            idx = feat_idx[0]
            if feat_cache[idx] is None:
                # first chunk: only remember it, no temporal conv
                feat_cache[idx] = x.clone()
            else:
                cache_x = x[:, :, -1:, :, :].clone()
                # prepend the last cached frame before the temporal conv
                x = self.time_conv(
                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                feat_cache[idx] = cache_x
            feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        """Zero ``conv`` and write an identity matrix at kernel tap [1, 0, 0]."""
        weight = conv.weight.data
        nn.init.zeros_(weight)
        c1, c2, t, h, w = weight.size()
        weight[:, :, 1, 0, 0] = torch.eye(c1, c2)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        """Zero ``conv`` and write the same identity matrix into both halves
        of the output channels at the last temporal tap."""
        weight = conv.weight.data
        nn.init.zeros_(weight)
        c1, c2, t, h, w = weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        nn.init.zeros_(conv.bias.data)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class ResidualBlock(nn.Module):
    """Two norm/SiLU/causal-conv stages with a residual shortcut.

    The shortcut is a 1x1x1 ``CausalConv3d`` when ``in_dim != out_dim`` and
    the identity otherwise.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Apply the block.

        Args:
            x: Input tensor, sliced along dim 2 for caching.
            feat_cache: Optional list of cache slots; each ``CausalConv3d``
                in the residual path reads and then overwrites one slot.
            feat_idx: One-element list with the next slot index, advanced
                in place. ``None`` (default) starts a fresh counter at 0.
        """
        # A fresh list per call: the old `feat_idx=[0]` default was mutated
        # below and therefore leaked state between calls.
        if feat_idx is None:
            feat_idx = [0]

        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class AttentionBlock(nn.Module):
    """
    Single-head self-attention over the spatial positions of each frame.

    Frames are folded into the batch dimension, so attention never mixes
    different frames; no attention mask is passed.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        residual = x
        b, c, t, h, w = x.size()
        frames = rearrange(x, 'b c t h w -> (b t) c h w')
        frames = self.norm(frames)

        # compute query, key, value: [(b t), 1, h*w, c] each
        qkv = self.to_qkv(frames).reshape(b * t, 1, c * 3, -1)
        qkv = qkv.permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        attended = F.scaled_dot_product_attention(q, k, v)
        attended = attended.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        out = self.proj(attended)
        out = rearrange(out, '(b t) c h w-> b c t h w', t=t)
        return out + residual
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class Encoder3d(nn.Module):
    """Causal 3d convolutional encoder.

    Structure: input conv -> per-level residual blocks (with attention
    blocks at the scales listed in ``attn_scales``) and a ``Resample``
    between levels -> middle (residual, attention, residual) -> output head
    projecting to ``z_dim`` channels.

    Args:
        dim: Base channel count.
        z_dim: Output channel count.
        dim_mult: Channel multiplier per level (default ``[1, 2, 4, 4]``).
        num_res_blocks: Residual blocks per level.
        attn_scales: Scales (1.0, 0.5, ...) at which attention is added
            (default: none).
        temperal_downsample: Per-transition flag choosing 3d over 2d
            downsampling (default ``[True, True, False]``).
        dropout: Dropout rate inside the residual blocks.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=None,
                 num_res_blocks=2,
                 attn_scales=None,
                 temperal_downsample=None,
                 dropout=0.0):
        super().__init__()
        # None sentinels replace shared mutable list defaults
        if dim_mult is None:
            dim_mult = [1, 2, 4, 4]
        if attn_scales is None:
            attn_scales = []
        if temperal_downsample is None:
            temperal_downsample = [True, True, False]
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + list(dim_mult)]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block (none after the last level)
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    @staticmethod
    def _conv_with_cache(conv, x, feat_cache, feat_idx):
        """Run ``conv`` with its cache slot, then refresh the slot and advance the index."""
        idx = feat_idx[0]
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # cache last frame of last two chunk
            cache_x = torch.cat([
                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                    cache_x.device), cache_x
            ], dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """Encode ``x``.

        Args:
            x: Input tensor, sliced along dim 2 for caching.
            feat_cache: Optional list of cache slots, one per causal conv,
                read and updated in place.
            feat_idx: One-element list with the next slot index, advanced
                in place. ``None`` (default) starts a fresh counter at 0.
        """
        # A fresh list per call: the old `feat_idx=[0]` default was mutated
        # and therefore leaked state between calls.
        if feat_idx is None:
            feat_idx = [0]
        use_cache = feat_cache is not None

        if use_cache:
            x = self._conv_with_cache(self.conv1, x, feat_cache, feat_idx)
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            # AttentionBlock.forward only accepts `x`
            if use_cache and not isinstance(layer, AttentionBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if use_cache and isinstance(layer, ResidualBlock):
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if use_cache and isinstance(layer, CausalConv3d):
                x = self._conv_with_cache(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class Decoder3d(nn.Module):
|
| 370 |
+
|
| 371 |
+
def __init__(self,
|
| 372 |
+
dim=128,
|
| 373 |
+
z_dim=4,
|
| 374 |
+
dim_mult=[1, 2, 4, 4],
|
| 375 |
+
num_res_blocks=2,
|
| 376 |
+
attn_scales=[],
|
| 377 |
+
temperal_upsample=[False, True, True],
|
| 378 |
+
dropout=0.0):
|
| 379 |
+
super().__init__()
|
| 380 |
+
self.dim = dim
|
| 381 |
+
self.z_dim = z_dim
|
| 382 |
+
self.dim_mult = dim_mult
|
| 383 |
+
self.num_res_blocks = num_res_blocks
|
| 384 |
+
self.attn_scales = attn_scales
|
| 385 |
+
self.temperal_upsample = temperal_upsample
|
| 386 |
+
|
| 387 |
+
# dimensions
|
| 388 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 389 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
| 390 |
+
|
| 391 |
+
# init block
|
| 392 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 393 |
+
|
| 394 |
+
# middle blocks
|
| 395 |
+
self.middle = nn.Sequential(
|
| 396 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
| 397 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
| 398 |
+
|
| 399 |
+
# upsample blocks
|
| 400 |
+
upsamples = []
|
| 401 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 402 |
+
# residual (+attention) blocks
|
| 403 |
+
if i == 1 or i == 2 or i == 3:
|
| 404 |
+
in_dim = in_dim // 2
|
| 405 |
+
for _ in range(num_res_blocks + 1):
|
| 406 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 407 |
+
if scale in attn_scales:
|
| 408 |
+
upsamples.append(AttentionBlock(out_dim))
|
| 409 |
+
in_dim = out_dim
|
| 410 |
+
|
| 411 |
+
# upsample block
|
| 412 |
+
if i != len(dim_mult) - 1:
|
| 413 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
| 414 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 415 |
+
scale *= 2.0
|
| 416 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 417 |
+
|
| 418 |
+
# output blocks
|
| 419 |
+
self.head = nn.Sequential(
|
| 420 |
+
RMS_norm(out_dim, images=False), nn.SiLU(),
|
| 421 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
| 422 |
+
|
| 423 |
+
def forward(self, x, feat_cache=None, feat_idx=None):
    """
    Decode ``x`` through conv1, the middle stage, the upsample stages and the head.

    Args:
        x: Input tensor; it is sliced as ``x[:, :, -CACHE_T:, :, :]``, so it
            must have five dimensions with time on axis 2.
        feat_cache: Optional list of per-convolution caches. When given, the
            causal convolutions receive their cached frames and the list is
            updated in place; when ``None`` every layer is called as
            ``layer(x)``.
        feat_idx: One-element list holding the index of the next cache slot.
            It is advanced in place. Defaults to a fresh ``[0]`` per call
            (the previous ``[0]`` default was a single shared list, so the
            counter leaked from one call into the next).

    Returns:
        The tensor produced by the head.
    """
    if feat_idx is None:
        feat_idx = [0]
    use_cache = feat_cache is not None

    def cached_conv(conv, inp):
        # Run one causal conv against its cache slot, then roll the slot
        # forward to the trailing CACHE_T frames of this call's input.
        idx = feat_idx[0]
        cache_x = inp[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # Fewer than two frames available: prepend the last frame of
            # the previous chunk's cache.
            previous = feat_cache[idx][:, :, -1, :, :].unsqueeze(2)
            cache_x = torch.cat([previous.to(cache_x.device), cache_x], dim=2)
        out = conv(inp, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return out

    ## conv1
    x = cached_conv(self.conv1, x) if use_cache else self.conv1(x)

    ## middle: only residual blocks take the cache
    for layer in self.middle:
        if use_cache and isinstance(layer, ResidualBlock):
            x = layer(x, feat_cache, feat_idx)
        else:
            x = layer(x)

    ## upsamples: every layer takes the cache when one is supplied
    for layer in self.upsamples:
        x = layer(x, feat_cache, feat_idx) if use_cache else layer(x)

    ## head: only the causal conv takes the cache
    for layer in self.head:
        if use_cache and isinstance(layer, CausalConv3d):
            x = cached_conv(layer, x)
        else:
            x = layer(x)
    return x
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def count_conv3d(model):
    """Return how many modules yielded by ``model.modules()`` are ``CausalConv3d``."""
    return sum(1 for module in model.modules()
               if isinstance(module, CausalConv3d))
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
class WanVAE_(nn.Module):
    """
    Video VAE built from ``Encoder3d``/``Decoder3d`` with chunked, cached passes.

    The encoder is fed the first frame alone and then groups of four frames;
    the decoder is fed one latent frame at a time. Per-convolution feature
    caches (``_enc_feat_map`` / ``_feat_map``) carry state between chunks and
    are reset before and after every encode/decode.
    """

    # (shift, multiplier) pair that leaves latents untouched in decode().
    _IDENTITY_SCALE = (0.0, 1.0)

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        """
        Args:
            dim: Base channel width passed to encoder and decoder.
            z_dim: Latent channels; the encoder emits ``2 * z_dim`` channels
                that are split into mean and log-variance.
            dim_mult: Per-stage width multipliers.
            num_res_blocks: Residual blocks per stage.
            attn_scales: Scales at which attention blocks are inserted.
            temperal_downsample: Per-stage temporal downsampling flags for
                the encoder; the decoder receives the reversed list.
            dropout: Dropout value passed to encoder and decoder.
        """
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        """
        Encode, sample a latent with the reparameterization trick, and decode.

        Works on raw (un-normalized) latents. The previous body called
        ``encode``/``decode`` without their required ``scale`` argument and
        unpacked a pair from ``encode``, which returns only the mean, so it
        always raised ``TypeError``.

        Returns:
            Tuple ``(x_recon, mu, log_var)``.
        """
        mu, log_var = self._encode_moments(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z, self._IDENTITY_SCALE)
        return x_recon, mu, log_var

    def _encode_moments(self, x):
        """Run the chunked encoder and return the raw ``(mu, log_var)`` pair."""
        self.clear_cache()
        t = x.shape[2]
        # Split x along time into chunks of 1, 4, 4, 4, ... frames. Frames
        # beyond 1 + 4 * (iter_ - 1) are not visited by this loop.
        iter_ = 1 + (t - 1) // 4
        chunks = []
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                frames = x[:, :, :1, :, :]
            else:
                frames = x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            chunks.append(
                self.encoder(
                    frames,
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx))
        # Concatenate once instead of re-copying the growing output per chunk.
        out = torch.cat(chunks, 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        self.clear_cache()
        return mu, log_var

    def encode(self, x, scale):
        """
        Encode ``x`` and return the normalized latent mean.

        Args:
            x: Video tensor with time on axis 2.
            scale: Pair ``(shift, multiplier)``; the result is
                ``(mu - shift) * multiplier``. Tensor entries are reshaped
                to ``(1, z_dim, 1, 1, 1)`` so they apply per channel.

        Returns:
            The normalized mean; the log-variance is discarded.
        """
        mu, _ = self._encode_moments(x)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        return mu

    def decode(self, z, scale):
        """
        Undo the latent normalization and decode one latent frame at a time.

        Args:
            z: Latent tensor ``[b, c, t, h, w]``.
            scale: Same pair as in ``encode``; applied inversely as
                ``z / multiplier + shift``.

        Returns:
            Decoder outputs of all latent frames concatenated along axis 2.
        """
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        frames = []
        for i in range(iter_):
            self._conv_idx = [0]
            frames.append(
                self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx))
        out = torch.cat(frames, 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps`` drawn from N(0, 1)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        """
        Encode ``imgs`` and return a raw latent sample (or the mean).

        The previous body unpacked a pair from ``encode(imgs)``, which lacks
        the required ``scale`` argument and returns a single tensor, so it
        always raised ``TypeError``.

        Args:
            imgs: Input tensor for the encoder.
            deterministic: When True, return the mean without sampling.
        """
        mu, log_var = self._encode_moments(imgs)
        if deterministic:
            return mu
        # Clamp the log-variance so exp() stays finite.
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        """Reset the decoder and encoder feature caches and their slot counters."""
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.

    Builds a ``WanVAE_`` from the default configuration below and optionally
    loads a checkpoint into it.

    Args:
        pretrained_path: Checkpoint path handed to ``torch.load``; when
            ``None`` the model keeps its fresh initialization.
        z_dim: Latent channel count forwarded to ``WanVAE_``. The ``None``
            default must be overridden, since ``WanVAE_`` computes
            ``z_dim * 2``.
        device: Used only as ``map_location`` when loading the checkpoint.
        **kwargs: Overrides for any entry of the default configuration.

    Returns:
        The constructed ``WanVAE_``.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    model = WanVAE_(**cfg)

    # load checkpoint (log only when there is one; this used to print
    # "loading None" for un-pretrained models)
    if pretrained_path is not None:
        logging.info('loading %s', pretrained_path)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources, or pass weights_only=True if the checkpoint
        # is a plain state dict.
        state_dict = torch.load(pretrained_path, map_location=device)
        # assign=True installs the loaded tensors as the module's parameters
        # instead of copying them into the existing ones.
        model.load_state_dict(state_dict, assign=True)

    return model
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
class WanVAE:
    """
    Inference wrapper around the model returned by ``_video_vae``.

    Holds per-channel latent statistics (16 values each) and passes them to
    the model as ``scale = [mean, 1 / std]`` on every encode/decode call.
    """

    # Per-channel latent statistics; 16 entries each.
    _LATENT_MEAN = (
        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
    )
    _LATENT_STD = (
        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
    )

    def __init__(self,
                 z_dim=16,
                 vae_pth=None,
                 dtype=torch.float,
                 device="cuda"):
        """
        Args:
            z_dim: Latent channel count forwarded to ``_video_vae``.
            vae_pth: Optional checkpoint path forwarded to ``_video_vae``.
            dtype: Dtype of the statistics tensors and of the autocast
                context used in ``encode``/``decode``.
            device: Device for the statistics tensors and the model.
        """
        self.dtype = dtype
        self.device = device

        self.mean = torch.tensor(self._LATENT_MEAN, dtype=dtype, device=device)
        self.std = torch.tensor(self._LATENT_STD, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model: eval mode, gradients off, moved to the target device
        model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim)
        model = model.eval()
        model = model.requires_grad_(False)
        self.model = model.to(device)

    @torch.no_grad()
    def encode(self, videos, device):
        """
        videos: A list of videos each with shape [C, T, H, W].

        Each video is given a batch axis, moved to ``device`` with
        ``self.dtype``, encoded, and returned as a float tensor with the
        batch axis removed again.
        """
        latents = []
        with torch.amp.autocast('cuda', dtype=self.dtype):
            for video in videos:
                batch = video.unsqueeze(0).to(device, self.dtype)
                latent = self.model.encode(batch, self.scale)
                latents.append(latent.float().squeeze(0))
        return latents

    @torch.no_grad()
    def decode(self, zs):
        """
        Decode a list of latents; outputs are float tensors clamped to
        [-1, 1] with the batch axis removed.
        """
        videos = []
        with torch.amp.autocast('cuda', dtype=self.dtype):
            for latent in zs:
                decoded = self.model.decode(latent.unsqueeze(0), self.scale)
                videos.append(decoded.float().clamp_(-1, 1).squeeze(0))
        return videos
|
humo/models/wan_modules/xlm_roberta.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
|
| 2 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
__all__ = ['XLMRoberta', 'xlm_roberta_large']
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SelfAttention(nn.Module):
    """Multi-head self-attention on top of ``F.scaled_dot_product_attention``."""

    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers: query / key / value / output projections, created in this order
        for name in ('q', 'k', 'v', 'o'):
            setattr(self, name, nn.Linear(dim, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        mask: passed to ``F.scaled_dot_product_attention`` as its attention mask.
        """
        batch, seq_len, channels = x.size()
        heads, head_dim = self.num_heads, self.head_dim

        def split_heads(t):
            # [B, L, C] -> [B, heads, L, head_dim]
            return t.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)

        # compute query, key, value
        query = split_heads(self.q(x))
        key = split_heads(self.k(x))
        value = split_heads(self.v(x))

        # compute attention; attention dropout is active only while training
        drop_p = self.dropout.p if self.training else 0.0
        attended = F.scaled_dot_product_attention(query, key, value, mask, drop_p)
        merged = attended.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)

        # output
        return self.dropout(self.o(merged))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AttentionBlock(nn.Module):
    """Transformer block: self-attention and a 4x GELU feed-forward, each with a residual."""

    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        hidden = dim * 4
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        """Apply the block; ``post_norm`` decides whether norms follow or precede each sublayer."""
        if not self.post_norm:
            # pre-norm: normalize the sublayer input, add the result to the residual
            x = x + self.attn(self.norm1(x), mask)
            return x + self.ffn(self.norm2(x))
        # post-norm: normalize the sum of residual and sublayer output
        x = self.norm1(x + self.attn(x, mask))
        return self.norm2(x + self.ffn(x))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings (token and position tables both use pad_id as padding index)
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        layers = []
        for _ in range(num_layers):
            layers.append(AttentionBlock(dim, num_heads, post_norm, dropout, eps))
        self.blocks = nn.ModuleList(layers)

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        batch, seq_len = ids.shape
        keep = ids.ne(self.pad_id).long()  # 1 for real tokens, 0 for padding

        # Real tokens are numbered pad_id + 1, pad_id + 2, ... along the
        # sequence; padding positions map to pad_id itself.
        positions = self.pad_id + torch.cumsum(keep, dim=1) * keep

        # embeddings: token + type (always type 0) + position
        x = self.token_embedding(ids)
        x = x + self.type_embedding(torch.zeros_like(ids))
        x = x + self.pos_embedding(positions)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks: additive mask, 0 for real tokens and the most negative
        # value of x's dtype for padding
        attn_bias = torch.where(
            keep.view(batch, 1, 1, seq_len).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, attn_bias)

        # output: pre-norm models are normalized once at the very end
        return x if self.post_norm else self.norm(x)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.

    Builds an ``XLMRoberta`` with the large configuration on ``device``;
    keyword arguments override individual configuration entries.
    ``pretrained`` and ``return_tokenizer`` are accepted but not used here.
    """
    # params
    defaults = {
        'vocab_size': 250002,
        'max_seq_len': 514,
        'type_size': 1,
        'pad_id': 1,
        'dim': 1024,
        'num_heads': 16,
        'num_layers': 24,
        'post_norm': True,
        'dropout': 0.1,
        'eps': 1e-5,
    }
    cfg = {**defaults, **kwargs}

    # init a model on device
    with torch.device(device):
        return XLMRoberta(**cfg)
|
humo/utils/audio_processor_whisper.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=C0301
|
| 2 |
+
'''
|
| 3 |
+
This module contains the AudioProcessor class and related functions for processing audio data.
|
| 4 |
+
It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
|
| 5 |
+
and audio separation. The class is initialized with configuration parameters and can process
|
| 6 |
+
audio files using the provided models.
|
| 7 |
+
'''
|
| 8 |
+
import os
|
| 9 |
+
import subprocess
|
| 10 |
+
|
| 11 |
+
import librosa
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from audio_separator.separator import Separator
|
| 15 |
+
from transformers import WhisperModel, AutoFeatureExtractor
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
    """
    Linearly resample ``features`` along axis 1.

    Args:
        features: Tensor whose axes 1 and 2 are swapped before and after
            interpolation (the original comment gives ``[1, C, T]`` for the
            swapped layout).
        input_fps: Rate of the incoming sequence.
        output_fps: Rate of the returned sequence.
        output_len: Explicit output length; when ``None`` it is derived as
            ``int(T / input_fps * output_fps)``.

    Returns:
        The resampled tensor in the same axis layout as ``features``.
    """
    channels_first = features.transpose(1, 2)  # [1, C, T]
    duration = channels_first.shape[2] / float(input_fps)
    target_len = int(duration * output_fps) if output_len is None else output_len
    resampled = F.interpolate(
        channels_first, size=target_len, align_corners=True, mode='linear')
    return resampled.transpose(1, 2)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """
    Resample an audio file with ffmpeg.

    Args:
        input_audio_file: Path of the audio file to read.
        output_audio_file: Path to write; an existing file is overwritten
            (``-y``).
        sample_rate: Target sampling rate passed to ffmpeg's ``-ar`` option.

    Returns:
        ``output_audio_file``.

    Raises:
        AssertionError: If ffmpeg exits with a non-zero status.
    """
    command = [
        "ffmpeg", "-y", "-v", "error", "-i", input_audio_file, "-ar", str(sample_rate), output_audio_file
    ]
    ret = subprocess.run(command, check=False).returncode
    # Raised explicitly rather than via `assert` so the check survives
    # `python -O`; the exception type is kept for existing callers.
    if ret != 0:
        raise AssertionError("Resample audio failed!")
    return output_audio_file
|
| 35 |
+
|
| 36 |
+
class AudioProcessor:
    """
    Converts audio files into Whisper-encoder embeddings and slices them
    into windows indexed by latent frame.

    :param sample_rate: Stored on the instance as ``sample_rate``
    :param fps: Stored on the instance as ``fps``
    :param wav2vec_model_path: Path from which the Whisper model and its
        feature extractor are loaded (the name is historical)
    :param wav2vec_feature_type: Accepted but not used by this class
    :param audio_separator_model_path: Directory of the audio separator model
    :param audio_separator_model_name: Name of the audio separator model;
        when ``None`` no separator is created
    :param cache_dir: Output directory handed to the separator
    :param device: Device the Whisper model is moved to
    """
    def __init__(
        self,
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_feature_type,
        audio_separator_model_path:str=None,
        audio_separator_model_name:str=None,
        cache_dir:str='',
        device="cuda:0",
    ) -> None:
        self.sample_rate = sample_rate
        self.fps = fps
        self.device = device

        # Frozen Whisper model in eval mode plus its matching feature extractor.
        self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
        self.whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)

        if audio_separator_model_name is None:
            self.audio_separator=None
            print("Use audio directly without vocals seperator.")
            return

        # Best effort: a failure to create the cache dir is only reported.
        try:
            os.makedirs(cache_dir, exist_ok=True)
        except OSError:
            print("Fail to create the output cache dir.")
        self.audio_separator = Separator(
            output_dir=cache_dir,
            output_single_stem="vocals",
            model_file_dir=audio_separator_model_path,
        )
        self.audio_separator.load_model(audio_separator_model_name)
        assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."

    def get_audio_feature(self, audio_path):
        """
        Load ``audio_path`` at 16 kHz and run the feature extractor over it
        in windows of 750 * 640 samples.

        Returns the per-window ``input_features`` concatenated along the last
        axis, and the number of samples divided by 640.
        """
        waveform, sampling_rate = librosa.load(audio_path, sr=16000)
        assert sampling_rate == 16000

        window = 750*640
        pieces = [
            self.feature_extractor(
                waveform[start:start+window],
                sampling_rate=sampling_rate,
                return_tensors="pt",
            ).input_features
            for start in range(0, len(waveform), window)
        ]
        return torch.cat(pieces, dim=-1), len(waveform) // 640

    def preprocess(self, audio_path: str):
        """
        Encode an audio file with the Whisper encoder.

        Returns the embedding produced by ``audio_emb_enc`` and the size of
        its first axis.
        """
        features, audio_len = self.get_audio_feature(audio_path)
        features = features.to(self.whisper.device).float()

        window = 3000
        encoded = []
        for start in range(0, features.shape[-1], window):
            segment = features[:,:,start:start+window]
            states = self.whisper.encoder(segment, output_hidden_states=True).hidden_states
            # Stack every hidden state on a new axis 2.
            encoded.append(torch.stack(states, dim=2))

        # Join the segments along axis 1 and keep audio_len * 2 entries of it.
        audio_prompts = torch.cat(encoded, dim=1)[:,:audio_len*2]

        audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
        return audio_emb, audio_emb.shape[0]

    def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
        """
        Reduce stacked hidden states to five feature groups.

        For ``"wav2vec"`` the input is returned as is. For ``"whisper"``
        layers 0-7, 8-15, 16-23 and 24-31 are averaged per group, layer 32
        is taken alone, each group is resampled from 50 to 25 fps, and the
        five are stacked on axis 2 before the leading axis is dropped.
        Any other value raises ``ValueError``.
        """
        if wav_enc_type == "wav2vec":
            return audio_emb
        if wav_enc_type != "whisper":
            raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")

        # [1, T, 33, 1280]
        groups = [audio_emb[:, :, lo:lo + 8].mean(dim=2) for lo in range(0, 32, 8)]
        groups.append(audio_emb[:, :, 32])
        resampled = [linear_interpolation_fps(group, 50, 25) for group in groups]
        return torch.stack(resampled, dim=2)[0]  # [T, 5, 1280]

    def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
        """
        Cut ``audio_emb`` into one window per latent frame.

        There are ``1 + (frame_num - 1) // 4`` windows. The first covers
        ``frame0_idx - 2 .. frame0_idx + 2`` with three zero entries put in
        front; each later one covers four frames plus ``audio_shift`` on both
        sides. Indices outside ``audio_emb`` are filled with zeros.
        NOTE(review): the windows only stack when ``audio_shift`` is 2,
        because the first window always has 8 entries.

        Returns the stacked windows and the end index of the last window
        minus ``audio_shift``.
        """
        total = audio_emb.shape[0]
        feat_shape = (audio_emb.shape[1], audio_emb.shape[2])
        zero_frame = torch.zeros(feat_shape, dtype=audio_emb.dtype, device=audio_emb.device)
        zero_pad = torch.zeros((3,) + feat_shape, dtype=audio_emb.dtype, device=audio_emb.device)

        def gather(start, stop):
            # Collect audio_emb[start:stop], substituting zeros out of range.
            return torch.stack([
                audio_emb[j] if (0 <= j < total) else zero_frame
                for j in range(start, stop)
            ], dim=0)

        windows = []
        for lt_i in range(1 + (frame_num - 1) // 4):  # latent_i
            if lt_i == 0:
                # First VAE latent frame: left-pad the audio window with
                # zeros so it is marked as the first frame.
                st = frame0_idx + lt_i - 2
                ed = frame0_idx + lt_i + 3
                window = torch.cat((zero_pad, gather(st, ed)), dim=0)  # [8, 13, 768]
            else:
                st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
                ed = frame0_idx + 1 + 4 * lt_i + audio_shift
                window = gather(st, ed)  # [8, 13, 768]
            windows.append(window)
        audio_emb_wind = torch.stack(windows, dim=0)  # [iter_, 8, 13, 768]

        return audio_emb_wind, ed - audio_shift

    def close(self):
        """
        TODO: to be implemented
        """
        return self

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.close()
|
humo/utils/wav2vec.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pylint: disable=R0901
|
| 2 |
+
# src/models/wav2vec.py
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
This module defines the Wav2Vec model, which is a pre-trained model for speech recognition and understanding.
|
| 6 |
+
It inherits from the Wav2Vec2Model class in the transformers library and provides additional functionalities
|
| 7 |
+
such as feature extraction and encoding.
|
| 8 |
+
|
| 9 |
+
Classes:
|
| 10 |
+
Wav2VecModel: Inherits from Wav2Vec2Model and adds additional methods for feature extraction and encoding.
|
| 11 |
+
|
| 12 |
+
Functions:
|
| 13 |
+
linear_interpolation: Interpolates the features based on the sequence length.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from transformers import Wav2Vec2Model
|
| 18 |
+
from transformers.modeling_outputs import BaseModelOutput
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Wav2VecModel(Wav2Vec2Model):
|
| 22 |
+
"""
|
| 23 |
+
Wav2VecModel is a custom model class that extends the Wav2Vec2Model class from the transformers library.
|
| 24 |
+
It inherits all the functionality of the Wav2Vec2Model and adds additional methods for feature extraction and encoding.
|
| 25 |
+
...
|
| 26 |
+
|
| 27 |
+
Attributes:
|
| 28 |
+
base_model (Wav2Vec2Model): The base Wav2Vec2Model object.
|
| 29 |
+
|
| 30 |
+
Methods:
|
| 31 |
+
forward(input_values, seq_len, attention_mask=None, mask_time_indices=None
|
| 32 |
+
, output_attentions=None, output_hidden_states=None, return_dict=None):
|
| 33 |
+
Forward pass of the Wav2VecModel.
|
| 34 |
+
It takes input_values, seq_len, and other optional parameters as input and returns the output of the base model.
|
| 35 |
+
|
| 36 |
+
feature_extract(input_values, seq_len):
|
| 37 |
+
Extracts features from the input_values using the base model.
|
| 38 |
+
|
| 39 |
+
encode(extract_features, attention_mask=None, mask_time_indices=None, output_attentions=None, output_hidden_states=None, return_dict=None):
|
| 40 |
+
Encodes the extracted features using the base model and returns the encoded features.
|
| 41 |
+
"""
|
| 42 |
+
def forward(
    self,
    input_values,
    seq_len,
    attention_mask=None,
    mask_time_indices=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """
    Forward pass of the Wav2Vec model.

    Runs the feature extractor, resamples its output to ``seq_len`` with
    ``linear_interpolation``, then applies feature projection, masking, the
    encoder and (when present) the adapter.

    Args:
        input_values: The input values (waveform) to the model.
        seq_len: Target length handed to ``linear_interpolation``.
        attention_mask: Optional mask; when given it is reduced with
            ``_get_feature_vector_attention_mask`` to match the features.
        mask_time_indices: Forwarded to ``_mask_hidden_states``.
        output_attentions: Forwarded to the encoder as given.
        output_hidden_states: Falls back to
            ``self.config.output_hidden_states`` when ``None``.
        return_dict: Falls back to ``self.config.use_return_dict`` when
            ``None``.

    Returns:
        A ``BaseModelOutput`` when ``return_dict`` is truthy, otherwise the
        tuple ``(hidden_states,) + encoder_outputs[1:]``.
    """
    # NOTE(review): this mutates the model config on every call; the
    # `output_attentions` argument itself is passed to the encoder unchanged.
    self.config.output_attentions = True

    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Extract features, swap axes 1 and 2, then resample to seq_len.
    extract_features = self.feature_extractor(input_values)
    extract_features = extract_features.transpose(1, 2)
    extract_features = linear_interpolation(extract_features, seq_len=seq_len)

    if attention_mask is not None:
        # compute reduced attention_mask corresponding to feature vectors
        attention_mask = self._get_feature_vector_attention_mask(
            extract_features.shape[1], attention_mask, add_adapter=False
        )

    hidden_states, extract_features = self.feature_projection(extract_features)
    hidden_states = self._mask_hidden_states(
        hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
    )

    encoder_outputs = self.encoder(
        hidden_states,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # The first encoder output is the final hidden state.
    hidden_states = encoder_outputs[0]

    if self.adapter is not None:
        hidden_states = self.adapter(hidden_states)

    if not return_dict:
        return (hidden_states, ) + encoder_outputs[1:]
    return BaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def feature_extract(
    self,
    input_values,
    seq_len,
):
    """
    Run only the feature-extraction front end.

    Parameters:
        input_values: Input handed to ``self.feature_extractor``.
        seq_len: Target length handed to ``linear_interpolation``.

    Returns:
        The feature extractor output with axes 1 and 2 swapped, resampled
        by ``linear_interpolation`` to ``seq_len``.
    """
    swapped = self.feature_extractor(input_values).transpose(1, 2)
    return linear_interpolation(swapped, seq_len=seq_len)
|
| 132 |
+
|
| 133 |
+
def encode(
    self,
    extract_features,
    attention_mask=None,
    mask_time_indices=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    """
    Project, mask and encode already-extracted features.

    Args:
        extract_features: Features to encode; passed to
            ``self.feature_projection``. Its ``shape[1]`` is used as the
            feature-vector length when an attention mask is supplied.
        attention_mask: Optional mask. When given it is first reduced with
            ``self._get_feature_vector_attention_mask``.
        mask_time_indices: Optional indices forwarded to
            ``self._mask_hidden_states``.
        output_attentions: Forwarded to ``self.encoder`` as-is (it is NOT
            resolved against the config here).
        output_hidden_states: Forwarded to ``self.encoder``; falls back to
            ``self.config.output_hidden_states`` when ``None``.
        return_dict: Selects the return form; falls back to
            ``self.config.use_return_dict`` when ``None``.

    Returns:
        A ``BaseModelOutput`` when ``return_dict`` is truthy, otherwise a
        tuple whose first item is the final hidden states followed by the
        remaining encoder outputs.
    """
    # NOTE: this mutates the model config, so the flag stays set after the
    # call returns and is visible to everything else sharing this config.
    self.config.output_attentions = True

    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    if attention_mask is not None:
        # Reduce the mask so it lines up with the feature vectors.
        feature_len = extract_features.shape[1]
        attention_mask = self._get_feature_vector_attention_mask(
            feature_len, attention_mask, add_adapter=False
        )

    projected, extract_features = self.feature_projection(extract_features)
    masked = self._mask_hidden_states(
        projected,
        mask_time_indices=mask_time_indices,
        attention_mask=attention_mask,
    )

    encoder_outputs = self.encoder(
        masked,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    final_states = encoder_outputs[0]
    if self.adapter is not None:
        final_states = self.adapter(final_states)

    if return_dict:
        return BaseModelOutput(
            last_hidden_state=final_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
    return (final_states,) + encoder_outputs[1:]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def linear_interpolation(features, seq_len):
    """
    Linearly resample ``features`` along dim 1 to ``seq_len`` steps.

    Args:
        features (torch.Tensor): 3-D tensor whose dim 1 is resampled
            (presumably laid out as (batch, time, channels) -- confirm
            with callers).
        seq_len (int): Target length of dim 1; passed to ``F.interpolate``
            as ``size``.

    Returns:
        torch.Tensor: Tensor in the same layout as the input with dim 1
        of length ``seq_len``.
    """
    # F.interpolate resamples the trailing dim, so move dim 1 there first
    # and restore the original layout afterwards.
    swapped = features.transpose(1, 2)
    resampled = F.interpolate(
        swapped,
        size=seq_len,
        mode='linear',
        align_corners=True,
    )
    return resampled.transpose(1, 2)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
    """
    Linearly resample ``features`` along dim 1 from one frame rate to another.

    Args:
        features (torch.Tensor): 3-D tensor whose dim 1 is resampled
            (presumably (1, T, C) -- confirm with callers).
        input_fps: Frame rate the ``T`` input steps were sampled at.
        output_fps: Desired frame rate of the result.
        output_len (int, optional): Explicit target length. When omitted it
            is derived as ``floor(T * output_fps / input_fps)``.

    Returns:
        torch.Tensor: Tensor in the same layout as the input with dim 1 of
        length ``output_len``.
    """
    features = features.transpose(1, 2)  # [1, C, T]
    if output_len is None:
        # Multiply before dividing: (T / input_fps) * output_fps accumulates
        # float error and can land just below an exact integer result
        # (0.29 * 100 == 28.999999999999996), so int() would drop a frame.
        output_len = int(features.shape[2] * output_fps / float(input_fps))
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
|
main.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 6 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 7 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 8 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 9 |
+
# See the License for the specific language governing permissions and
|
| 10 |
+
# limitations under the License.
|
| 11 |
+
|
| 12 |
+
# Inference codes adapted from [SeedVR]
|
| 13 |
+
# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
|
| 14 |
+
|
| 15 |
+
from sys import argv
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
# Make the packages under "humo" importable as top-level modules
# (e.g. ``common.config``). The entry is relative, so it is resolved
# against the current working directory.
path_to_insert = "humo"
if path_to_insert not in sys.path:
    sys.path.insert(0, path_to_insert)

# Fail with a readable message instead of an IndexError traceback when the
# required first argument is missing.
if len(argv) < 2:
    sys.exit(f"usage: {argv[0] if argv else 'main.py'} <config> [extra args ...]")

# Must come after the sys.path tweak above.
from common.config import load_config, create_object

# Load config: first CLI argument plus all remaining arguments.
config = load_config(argv[1], argv[2:])

# Build the object described by the config and hand control to it.
runner = create_object(config)
runner.entrypoint()
|