Upload 24 files
Browse files- LICENSE.md +201 -0
- MANUAL_DE_USO.md +287 -0
- README-ja.md +177 -0
- README.md +609 -12
- XTI_hijack.py +204 -0
- _typos.toml +35 -0
- app.py +698 -0
- fine_tune.py +538 -0
- gen_img.py +0 -0
- gen_img_diffusers.py +0 -0
- requirements.txt +42 -0
- sdxl_gen_img.py +0 -0
- sdxl_minimal_inference.py +345 -0
- sdxl_train.py +952 -0
- sdxl_train_control_net_lllite.py +626 -0
- sdxl_train_control_net_lllite_old.py +586 -0
- sdxl_train_network.py +184 -0
- sdxl_train_textual_inversion.py +138 -0
- setup.py +3 -0
- train_controlnet.py +648 -0
- train_db.py +531 -0
- train_network.py +1250 -0
- train_textual_inversion.py +813 -0
- train_textual_inversion_XTI.py +720 -0
LICENSE.md
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [2022] [kohya-ss]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
MANUAL_DE_USO.md
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 📖 Manual de Uso - LoRA Trainer Funcional
|
| 2 |
+
|
| 3 |
+
## 🎯 Visão Geral
|
| 4 |
+
|
| 5 |
+
Este LoRA Trainer é uma ferramenta **100% funcional** baseada no kohya-ss sd-scripts que permite treinar modelos LoRA reais para Stable Diffusion. A ferramenta foi desenvolvida especificamente para funcionar no Hugging Face Spaces e oferece todas as funcionalidades necessárias para um treinamento profissional.
|
| 6 |
+
|
| 7 |
+
## 🚀 Início Rápido
|
| 8 |
+
|
| 9 |
+
### Passo 1: Instalação das Dependências
|
| 10 |
+
1. Acesse a aba "🔧 Instalação"
|
| 11 |
+
2. Clique em "📦 Instalar Dependências"
|
| 12 |
+
3. Aguarde a instalação completa (pode levar alguns minutos)
|
| 13 |
+
|
| 14 |
+
### Passo 2: Configuração do Projeto
|
| 15 |
+
1. Vá para a aba "📁 Configuração do Projeto"
|
| 16 |
+
2. Digite um nome único para seu projeto (ex: "meu_lora_anime")
|
| 17 |
+
3. Escolha um modelo base ou insira uma URL personalizada
|
| 18 |
+
4. Clique em "📥 Baixar Modelo"
|
| 19 |
+
|
| 20 |
+
### Passo 3: Preparação do Dataset
|
| 21 |
+
1. Organize suas imagens em uma pasta local
|
| 22 |
+
2. Para cada imagem, crie um arquivo .txt com o mesmo nome
|
| 23 |
+
3. Compacte tudo em um arquivo ZIP
|
| 24 |
+
4. Faça upload na seção "📊 Upload do Dataset"
|
| 25 |
+
5. Clique em "📊 Processar Dataset"
|
| 26 |
+
|
| 27 |
+
### Passo 4: Configuração dos Parâmetros
|
| 28 |
+
1. Acesse a aba "⚙️ Parâmetros de Treinamento"
|
| 29 |
+
2. Ajuste os parâmetros conforme sua necessidade
|
| 30 |
+
3. Use as configurações recomendadas como ponto de partida
|
| 31 |
+
|
| 32 |
+
### Passo 5: Treinamento
|
| 33 |
+
1. Vá para a aba "🚀 Treinamento"
|
| 34 |
+
2. Clique em "📝 Criar Configuração de Treinamento"
|
| 35 |
+
3. Clique em "🎯 Iniciar Treinamento"
|
| 36 |
+
4. Acompanhe o progresso em tempo real
|
| 37 |
+
|
| 38 |
+
### Passo 6: Download dos Resultados
|
| 39 |
+
1. Acesse a aba "📥 Download dos Resultados"
|
| 40 |
+
2. Clique em "🔄 Atualizar Lista de Arquivos"
|
| 41 |
+
3. Selecione e baixe seu LoRA treinado
|
| 42 |
+
|
| 43 |
+
## 📋 Requisitos do Sistema
|
| 44 |
+
|
| 45 |
+
### Mínimos
|
| 46 |
+
- **GPU**: NVIDIA com 6GB VRAM
|
| 47 |
+
- **RAM**: 8GB
|
| 48 |
+
- **Espaço**: 5GB livres
|
| 49 |
+
|
| 50 |
+
### Recomendados
|
| 51 |
+
- **GPU**: NVIDIA com 12GB+ VRAM
|
| 52 |
+
- **RAM**: 16GB+
|
| 53 |
+
- **Espaço**: 20GB+ livres
|
| 54 |
+
|
| 55 |
+
## 🎨 Preparação do Dataset
|
| 56 |
+
|
| 57 |
+
### Estrutura Recomendada
|
| 58 |
+
```
|
| 59 |
+
meu_dataset/
|
| 60 |
+
├── imagem001.jpg
|
| 61 |
+
├── imagem001.txt
|
| 62 |
+
├── imagem002.png
|
| 63 |
+
├── imagem002.txt
|
| 64 |
+
├── imagem003.webp
|
| 65 |
+
├── imagem003.txt
|
| 66 |
+
└── ...
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Formatos Suportados
|
| 70 |
+
- **Imagens**: JPG, PNG, WEBP, BMP, TIFF
|
| 71 |
+
- **Captions**: TXT (UTF-8)
|
| 72 |
+
|
| 73 |
+
### Exemplo de Caption
|
| 74 |
+
```
|
| 75 |
+
1girl, long hair, blue eyes, school uniform, smile, outdoors, cherry blossoms, anime style, high quality
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Dicas para Captions
|
| 79 |
+
- Use vírgulas para separar tags
|
| 80 |
+
- Coloque tags importantes no início
|
| 81 |
+
- Seja específico mas conciso
|
| 82 |
+
- Use tags consistentes em todo o dataset
|
| 83 |
+
|
| 84 |
+
## ⚙️ Configuração de Parâmetros
|
| 85 |
+
|
| 86 |
+
### Parâmetros Básicos
|
| 87 |
+
|
| 88 |
+
#### Resolução
|
| 89 |
+
- **512px**: Padrão, mais rápido, menor uso de memória
|
| 90 |
+
- **768px**: Melhor qualidade, moderado
|
| 91 |
+
- **1024px**: Máxima qualidade, mais lento
|
| 92 |
+
|
| 93 |
+
#### Batch Size
|
| 94 |
+
- **1**: Menor uso de memória, mais lento
|
| 95 |
+
- **2-4**: Equilibrado (recomendado)
|
| 96 |
+
- **8+**: Apenas para GPUs potentes
|
| 97 |
+
|
| 98 |
+
#### Épocas
|
| 99 |
+
- **5-10**: Para datasets grandes (50+ imagens)
|
| 100 |
+
- **10-20**: Para datasets médios (20-50 imagens)
|
| 101 |
+
- **20-30**: Para datasets pequenos (10-20 imagens)
|
| 102 |
+
|
| 103 |
+
### Parâmetros Avançados
|
| 104 |
+
|
| 105 |
+
#### Learning Rate
|
| 106 |
+
- **1e-3**: Muito alto, pode causar instabilidade
|
| 107 |
+
- **5e-4**: Padrão, bom para a maioria dos casos
|
| 108 |
+
- **1e-4**: Conservador, para datasets grandes
|
| 109 |
+
- **5e-5**: Muito baixo, treinamento lento
|
| 110 |
+
|
| 111 |
+
#### Network Dimension
|
| 112 |
+
- **8-16**: LoRAs pequenos, menos detalhes
|
| 113 |
+
- **32**: Padrão, bom equilíbrio
|
| 114 |
+
- **64-128**: Mais detalhes, arquivos maiores
|
| 115 |
+
|
| 116 |
+
#### Network Alpha
|
| 117 |
+
- Geralmente metade do Network Dimension
|
| 118 |
+
- Controla a força do LoRA
|
| 119 |
+
- Valores menores = efeito mais sutil
|
| 120 |
+
|
| 121 |
+
### Tipos de LoRA
|
| 122 |
+
|
| 123 |
+
#### LoRA Clássico
|
| 124 |
+
- Menor tamanho de arquivo
|
| 125 |
+
- Bom para uso geral
|
| 126 |
+
- Mais rápido para treinar
|
| 127 |
+
|
| 128 |
+
#### LoCon
|
| 129 |
+
- Melhor para estilos artísticos
|
| 130 |
+
- Mais camadas de aprendizado
|
| 131 |
+
- Arquivos maiores
|
| 132 |
+
|
| 133 |
+
## 🎯 Configurações por Tipo de Projeto
|
| 134 |
+
|
| 135 |
+
### Para Personagens/Pessoas
|
| 136 |
+
```
|
| 137 |
+
Imagens: 15-30 variadas
|
| 138 |
+
Network Dim: 32
|
| 139 |
+
Network Alpha: 16
|
| 140 |
+
Learning Rate: 1e-4
|
| 141 |
+
Épocas: 10-15
|
| 142 |
+
Batch Size: 2
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
### Para Estilos Artísticos
|
| 146 |
+
```
|
| 147 |
+
Imagens: 30-50 do estilo
|
| 148 |
+
Tipo: LoCon
|
| 149 |
+
Network Dim: 64
|
| 150 |
+
Network Alpha: 32
|
| 151 |
+
Learning Rate: 5e-5
|
| 152 |
+
Épocas: 15-25
|
| 153 |
+
Batch Size: 1-2
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
### Para Objetos/Conceitos
|
| 157 |
+
```
|
| 158 |
+
Imagens: 10-25
|
| 159 |
+
Network Dim: 16
|
| 160 |
+
Network Alpha: 8
|
| 161 |
+
Learning Rate: 5e-4
|
| 162 |
+
Épocas: 8-12
|
| 163 |
+
Batch Size: 2-4
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
## 🔧 Solução de Problemas
|
| 167 |
+
|
| 168 |
+
### Erro de Memória (CUDA OOM)
|
| 169 |
+
**Sintomas**: "CUDA out of memory"
|
| 170 |
+
**Soluções**:
|
| 171 |
+
- Reduza o batch size para 1
|
| 172 |
+
- Diminua a resolução para 512px
|
| 173 |
+
- Use mixed precision fp16
|
| 174 |
+
|
| 175 |
+
### Treinamento Muito Lento
|
| 176 |
+
**Sintomas**: Progresso muito lento
|
| 177 |
+
**Soluções**:
|
| 178 |
+
- Aumente o batch size (se possível)
|
| 179 |
+
- Use resolução menor
|
| 180 |
+
- Verifique se xFormers está ativo
|
| 181 |
+
|
| 182 |
+
### Resultados Ruins/Overfitting
|
| 183 |
+
**Sintomas**: LoRA não funciona ou muito forte
|
| 184 |
+
**Soluções**:
|
| 185 |
+
- Reduza o learning rate
|
| 186 |
+
- Diminua o número de épocas
|
| 187 |
+
- Use mais imagens variadas
|
| 188 |
+
- Ajuste network alpha
|
| 189 |
+
|
| 190 |
+
### Erro de Configuração
|
| 191 |
+
**Sintomas**: Falha ao criar configuração
|
| 192 |
+
**Soluções**:
|
| 193 |
+
- Verifique se o modelo foi baixado
|
| 194 |
+
- Confirme que o dataset foi processado
|
| 195 |
+
- Reinicie a aplicação
|
| 196 |
+
|
| 197 |
+
## 📊 Monitoramento do Treinamento
|
| 198 |
+
|
| 199 |
+
### Métricas Importantes
|
| 200 |
+
- **Loss**: Deve diminuir gradualmente
|
| 201 |
+
- **Learning Rate**: Varia conforme scheduler
|
| 202 |
+
- **Tempo por Época**: Depende do hardware
|
| 203 |
+
|
| 204 |
+
### Sinais de Bom Treinamento
|
| 205 |
+
- Loss diminui consistentemente
|
| 206 |
+
- Sem erros de memória
|
| 207 |
+
- Progresso estável
|
| 208 |
+
|
| 209 |
+
### Sinais de Problemas
|
| 210 |
+
- Loss oscila muito
|
| 211 |
+
- Erros frequentes
|
| 212 |
+
- Progresso muito lento
|
| 213 |
+
|
| 214 |
+
## 💾 Gerenciamento de Arquivos
|
| 215 |
+
|
| 216 |
+
### Estrutura de Saída
|
| 217 |
+
```
|
| 218 |
+
/tmp/lora_training/projects/meu_projeto/
|
| 219 |
+
├── dataset/ # Imagens processadas
|
| 220 |
+
├── output/ # LoRAs gerados
|
| 221 |
+
├── logs/ # Logs do treinamento
|
| 222 |
+
├── dataset_config.toml
|
| 223 |
+
└── training_config.toml
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
### Arquivos Gerados
|
| 227 |
+
- **projeto_epoch_0001.safetensors**: LoRA da época 1
|
| 228 |
+
- **projeto_epoch_0010.safetensors**: LoRA da época 10
|
| 229 |
+
- **logs/**: Logs detalhados do TensorBoard
|
| 230 |
+
|
| 231 |
+
## 🎨 Uso dos LoRAs Treinados
|
| 232 |
+
|
| 233 |
+
### No Automatic1111
|
| 234 |
+
1. Copie o arquivo .safetensors para `models/Lora/`
|
| 235 |
+
2. Use na prompt: `<lora:nome_do_arquivo:0.8>`
|
| 236 |
+
3. Ajuste o peso (0.1 a 1.5)
|
| 237 |
+
|
| 238 |
+
### No ComfyUI
|
| 239 |
+
1. Coloque o arquivo em `models/loras/`
|
| 240 |
+
2. Use o nó "Load LoRA"
|
| 241 |
+
3. Conecte ao modelo
|
| 242 |
+
|
| 243 |
+
### Pesos Recomendados
|
| 244 |
+
- **0.3-0.6**: Efeito sutil
|
| 245 |
+
- **0.7-1.0**: Efeito padrão
|
| 246 |
+
- **1.1-1.5**: Efeito forte
|
| 247 |
+
|
| 248 |
+
## 🔄 Melhores Práticas
|
| 249 |
+
|
| 250 |
+
### Antes do Treinamento
|
| 251 |
+
1. **Qualidade sobre Quantidade**: 20 imagens boas > 100 ruins
|
| 252 |
+
2. **Variedade**: Use ângulos, poses e cenários diferentes
|
| 253 |
+
3. **Consistência**: Mantenha estilo consistente nas captions
|
| 254 |
+
4. **Backup**: Salve configurações que funcionaram
|
| 255 |
+
|
| 256 |
+
### Durante o Treinamento
|
| 257 |
+
1. **Monitore**: Acompanhe o progresso regularmente
|
| 258 |
+
2. **Paciência**: Não interrompa sem necessidade
|
| 259 |
+
3. **Recursos**: Monitore uso de GPU/RAM
|
| 260 |
+
|
| 261 |
+
### Após o Treinamento
|
| 262 |
+
1. **Teste**: Experimente diferentes pesos
|
| 263 |
+
2. **Compare**: Teste épocas diferentes
|
| 264 |
+
3. **Documente**: Anote configurações que funcionaram
|
| 265 |
+
4. **Compartilhe**: Considere compartilhar bons resultados
|
| 266 |
+
|
| 267 |
+
## 🆘 Suporte e Recursos
|
| 268 |
+
|
| 269 |
+
### Documentação Adicional
|
| 270 |
+
- [Guia Oficial kohya-ss](https://github.com/kohya-ss/sd-scripts)
|
| 271 |
+
- [Documentação Diffusers](https://huggingface.co/docs/diffusers)
|
| 272 |
+
- [Comunidade Stable Diffusion](https://discord.gg/stable-diffusion)
|
| 273 |
+
|
| 274 |
+
### Logs e Debug
|
| 275 |
+
- Verifique os logs em `/tmp/lora_training/projects/seu_projeto/logs/`
|
| 276 |
+
- Use TensorBoard para visualizar métricas
|
| 277 |
+
- Salve configurações que funcionaram bem
|
| 278 |
+
|
| 279 |
+
### Limitações Conhecidas
|
| 280 |
+
- Requer GPU NVIDIA com CUDA
|
| 281 |
+
- Modelos grandes podem precisar de mais memória
|
| 282 |
+
- Treinamento pode ser lento em hardware limitado
|
| 283 |
+
|
| 284 |
+
---
|
| 285 |
+
|
| 286 |
+
**Nota**: Esta ferramenta é para fins educacionais e de pesquisa. Use responsavelmente e respeite direitos autorais das imagens utilizadas.
|
| 287 |
+
|
README-ja.md
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## リポジトリについて
|
| 2 |
+
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
|
| 3 |
+
|
| 4 |
+
[README in English](./README.md) ←更新情報はこちらにあります
|
| 5 |
+
|
| 6 |
+
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
|
| 7 |
+
|
| 8 |
+
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
|
| 9 |
+
|
| 10 |
+
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
| 11 |
+
|
| 12 |
+
以下のスクリプトがあります。
|
| 13 |
+
|
| 14 |
+
* DreamBooth、U-NetおよびText Encoderの学習をサポート
|
| 15 |
+
* fine-tuning、同上
|
| 16 |
+
* LoRAの学習をサポート
|
| 17 |
+
* 画像生成
|
| 18 |
+
* モデル変換(Stable Diffision ckpt/safetensorsとDiffusersの相互変換)
|
| 19 |
+
|
| 20 |
+
### スポンサー
|
| 21 |
+
|
| 22 |
+
このプロジェクトを支援してくださる企業・団体の皆様に深く感謝いたします。
|
| 23 |
+
|
| 24 |
+
<a href="https://aihub.co.jp/">
|
| 25 |
+
<img src="./images/logo_aihub.png" alt="AiHUB株式会社" title="AiHUB株式会社" height="100px">
|
| 26 |
+
</a>
|
| 27 |
+
|
| 28 |
+
### スポンサー募集のお知らせ
|
| 29 |
+
|
| 30 |
+
このプロジェクトがお役に立ったなら、ご支援いただけると嬉しく思います。 [GitHub Sponsors](https://github.com/sponsors/kohya-ss/)で受け付けています。
|
| 31 |
+
|
| 32 |
+
## 使用法について
|
| 33 |
+
|
| 34 |
+
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
|
| 35 |
+
* [データセット設定](./docs/config_README-ja.md)
|
| 36 |
+
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
|
| 37 |
+
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
|
| 38 |
+
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
|
| 39 |
+
* [LoRAの学習について](./docs/train_network_README-ja.md)
|
| 40 |
+
* [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
|
| 41 |
+
* [画像生成スクリプト](./docs/gen_img_README-ja.md)
|
| 42 |
+
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
| 43 |
+
|
| 44 |
+
## Windowsでの動作に必要なプログラム
|
| 45 |
+
|
| 46 |
+
Python 3.10.6およびGitが必要です。
|
| 47 |
+
|
| 48 |
+
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
| 49 |
+
- git: https://git-scm.com/download/win
|
| 50 |
+
|
| 51 |
+
Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
|
| 52 |
+
|
| 53 |
+
PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
|
| 54 |
+
(venvに限らずスクリプトの実行が可能になりますので注意してください。)
|
| 55 |
+
|
| 56 |
+
- PowerShellを管理者として開きます。
|
| 57 |
+
- 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
|
| 58 |
+
- 管理者のPowerShellを閉じます。
|
| 59 |
+
|
| 60 |
+
## Windows環境でのインストール
|
| 61 |
+
|
| 62 |
+
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
|
| 63 |
+
|
| 64 |
+
(なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
|
| 65 |
+
|
| 66 |
+
PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。
|
| 67 |
+
|
| 68 |
+
```powershell
|
| 69 |
+
git clone https://github.com/kohya-ss/sd-scripts.git
|
| 70 |
+
cd sd-scripts
|
| 71 |
+
|
| 72 |
+
python -m venv venv
|
| 73 |
+
.\venv\Scripts\activate
|
| 74 |
+
|
| 75 |
+
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
|
| 76 |
+
pip install --upgrade -r requirements.txt
|
| 77 |
+
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
|
| 78 |
+
|
| 79 |
+
accelerate config
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
コマンドプロンプトでも同一です。
|
| 83 |
+
|
| 84 |
+
注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
|
| 85 |
+
|
| 86 |
+
この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
|
| 87 |
+
|
| 88 |
+
PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
|
| 89 |
+
|
| 90 |
+
accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
|
| 91 |
+
|
| 92 |
+
```txt
|
| 93 |
+
- This machine
|
| 94 |
+
- No distributed training
|
| 95 |
+
- NO
|
| 96 |
+
- NO
|
| 97 |
+
- NO
|
| 98 |
+
- all
|
| 99 |
+
- fp16
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
|
| 103 |
+
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
|
| 104 |
+
|
| 105 |
+
## アップグレード
|
| 106 |
+
|
| 107 |
+
新しいリリースがあった場合、以下のコマンドで更新できます。
|
| 108 |
+
|
| 109 |
+
```powershell
|
| 110 |
+
cd sd-scripts
|
| 111 |
+
git pull
|
| 112 |
+
.\venv\Scripts\activate
|
| 113 |
+
pip install --use-pep517 --upgrade -r requirements.txt
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
コマンドが成功すれば新しいバージョンが使用できます。
|
| 117 |
+
|
| 118 |
+
## 謝意
|
| 119 |
+
|
| 120 |
+
LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
|
| 121 |
+
|
| 122 |
+
Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
|
| 123 |
+
|
| 124 |
+
## ライセンス
|
| 125 |
+
|
| 126 |
+
スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。
|
| 127 |
+
|
| 128 |
+
[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
|
| 129 |
+
|
| 130 |
+
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
|
| 131 |
+
|
| 132 |
+
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
|
| 133 |
+
|
| 134 |
+
## その他の情報
|
| 135 |
+
|
| 136 |
+
### LoRAの名称について
|
| 137 |
+
|
| 138 |
+
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
|
| 139 |
+
|
| 140 |
+
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
|
| 141 |
+
|
| 142 |
+
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
|
| 143 |
+
|
| 144 |
+
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
|
| 145 |
+
|
| 146 |
+
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
|
| 147 |
+
|
| 148 |
+
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
|
| 149 |
+
|
| 150 |
+
<!--
|
| 151 |
+
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
|
| 152 |
+
|
| 153 |
+
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
|
| 154 |
+
-->
|
| 155 |
+
|
| 156 |
+
### 学習中のサンプル画像生成
|
| 157 |
+
|
| 158 |
+
プロンプトファイルは例えば以下のようになります。
|
| 159 |
+
|
| 160 |
+
```
|
| 161 |
+
# prompt 1
|
| 162 |
+
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
| 163 |
+
|
| 164 |
+
# prompt 2
|
| 165 |
+
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
|
| 169 |
+
|
| 170 |
+
* `--n` Negative prompt up to the next option.
|
| 171 |
+
* `--w` Specifies the width of the generated image.
|
| 172 |
+
* `--h` Specifies the height of the generated image.
|
| 173 |
+
* `--d` Specifies the seed of the generated image.
|
| 174 |
+
* `--l` Specifies the CFG scale of the generated image.
|
| 175 |
+
* `--s` Specifies the number of steps in the generation.
|
| 176 |
+
|
| 177 |
+
`( )` や `[ ]` などの重みづけも動作します。
|
README.md
CHANGED
|
@@ -1,12 +1,609 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
This repository contains training, generation and utility scripts for Stable Diffusion.
|
| 2 |
+
|
| 3 |
+
[__Change History__](#change-history) is moved to the bottom of the page.
|
| 4 |
+
更新履歴は[ページ末尾](#change-history)に移しました。
|
| 5 |
+
|
| 6 |
+
Latest update: 2025-03-21 (Version 0.9.1)
|
| 7 |
+
|
| 8 |
+
[日本語版READMEはこちら](./README-ja.md)
|
| 9 |
+
|
| 10 |
+
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
|
| 11 |
+
|
| 12 |
+
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
|
| 16 |
+
|
| 17 |
+
This repository contains the scripts for:
|
| 18 |
+
|
| 19 |
+
* DreamBooth training, including U-Net and Text Encoder
|
| 20 |
+
* Fine-tuning (native training), including U-Net and Text Encoder
|
| 21 |
+
* LoRA training
|
| 22 |
+
* Textual Inversion training
|
| 23 |
+
* Image generation
|
| 24 |
+
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
|
| 25 |
+
|
| 26 |
+
### Sponsors
|
| 27 |
+
|
| 28 |
+
We are grateful to the following companies for their generous sponsorship:
|
| 29 |
+
|
| 30 |
+
<a href="https://aihub.co.jp/top-en">
|
| 31 |
+
<img src="./images/logo_aihub.png" alt="AiHUB Inc." title="AiHUB Inc." height="100px">
|
| 32 |
+
</a>
|
| 33 |
+
|
| 34 |
+
### Support the Project
|
| 35 |
+
|
| 36 |
+
If you find this project helpful, please consider supporting its development via [GitHub Sponsors](https://github.com/sponsors/kohya-ss/). Your support is greatly appreciated!
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
## About requirements.txt
|
| 40 |
+
|
| 41 |
+
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
|
| 42 |
+
|
| 43 |
+
The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers.
|
| 44 |
+
|
| 45 |
+
## Links to usage documentation
|
| 46 |
+
|
| 47 |
+
Most of the documents are written in Japanese.
|
| 48 |
+
|
| 49 |
+
[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!
|
| 50 |
+
|
| 51 |
+
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
|
| 52 |
+
* [Chinese version](./docs/train_README-zh.md)
|
| 53 |
+
* [SDXL training](./docs/train_SDXL-en.md) (English version)
|
| 54 |
+
* [Dataset config](./docs/config_README-ja.md)
|
| 55 |
+
* [English version](./docs/config_README-en.md)
|
| 56 |
+
* [DreamBooth training guide](./docs/train_db_README-ja.md)
|
| 57 |
+
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
|
| 58 |
+
* [Training LoRA](./docs/train_network_README-ja.md)
|
| 59 |
+
* [Training Textual Inversion](./docs/train_ti_README-ja.md)
|
| 60 |
+
* [Image generation](./docs/gen_img_README-ja.md)
|
| 61 |
+
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
| 62 |
+
|
| 63 |
+
## Windows Required Dependencies
|
| 64 |
+
|
| 65 |
+
Python 3.10.6 and Git:
|
| 66 |
+
|
| 67 |
+
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
|
| 68 |
+
- git: https://git-scm.com/download/win
|
| 69 |
+
|
| 70 |
+
Python 3.10.x, 3.11.x, and 3.12.x will work but not tested.
|
| 71 |
+
|
| 72 |
+
Give unrestricted script access to powershell so venv can work:
|
| 73 |
+
|
| 74 |
+
- Open an administrator powershell window
|
| 75 |
+
- Type `Set-ExecutionPolicy Unrestricted` and answer A
|
| 76 |
+
- Close admin powershell window
|
| 77 |
+
|
| 78 |
+
## Windows Installation
|
| 79 |
+
|
| 80 |
+
Open a regular Powershell terminal and type the following inside:
|
| 81 |
+
|
| 82 |
+
```powershell
|
| 83 |
+
git clone https://github.com/kohya-ss/sd-scripts.git
|
| 84 |
+
cd sd-scripts
|
| 85 |
+
|
| 86 |
+
python -m venv venv
|
| 87 |
+
.\venv\Scripts\activate
|
| 88 |
+
|
| 89 |
+
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
|
| 90 |
+
pip install --upgrade -r requirements.txt
|
| 91 |
+
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
|
| 92 |
+
|
| 93 |
+
accelerate config
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
If `python -m venv` shows only `python`, change `python` to `py`.
|
| 97 |
+
|
| 98 |
+
Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
|
| 99 |
+
|
| 100 |
+
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
|
| 101 |
+
|
| 102 |
+
If you use PyTorch 2.2 or later, please change `torch==2.1.2` and `torchvision==0.16.2` and `xformers==0.0.23.post1` to the appropriate version.
|
| 103 |
+
|
| 104 |
+
<!--
|
| 105 |
+
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
|
| 106 |
+
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
|
| 107 |
+
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
| 108 |
+
-->
|
| 109 |
+
Answers to accelerate config:
|
| 110 |
+
|
| 111 |
+
```txt
|
| 112 |
+
- This machine
|
| 113 |
+
- No distributed training
|
| 114 |
+
- NO
|
| 115 |
+
- NO
|
| 116 |
+
- NO
|
| 117 |
+
- all
|
| 118 |
+
- fp16
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
If you'd like to use bf16, please answer `bf16` to the last question.
|
| 122 |
+
|
| 123 |
+
Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
|
| 124 |
+
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
|
| 125 |
+
|
| 126 |
+
(Single GPU with id `0` will be used.)
|
| 127 |
+
|
| 128 |
+
## Upgrade
|
| 129 |
+
|
| 130 |
+
When a new release comes out you can upgrade your repo with the following command:
|
| 131 |
+
|
| 132 |
+
```powershell
|
| 133 |
+
cd sd-scripts
|
| 134 |
+
git pull
|
| 135 |
+
.\venv\Scripts\activate
|
| 136 |
+
pip install --use-pep517 --upgrade -r requirements.txt
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
Once the commands have completed successfully you should be ready to use the new version.
|
| 140 |
+
|
| 141 |
+
### Upgrade PyTorch
|
| 142 |
+
|
| 143 |
+
If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section. `xformers` is also required to be upgraded when PyTorch is upgraded.
|
| 144 |
+
|
| 145 |
+
## Credits
|
| 146 |
+
|
| 147 |
+
The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!
|
| 148 |
+
|
| 149 |
+
The LoRA expansion to Conv2d 3x3 was initially released by cloneofsimo and its effectiveness was demonstrated at [LoCon](https://github.com/KohakuBlueleaf/LoCon) by KohakuBlueleaf. Thank you so much KohakuBlueleaf!
|
| 150 |
+
|
| 151 |
+
## License
|
| 152 |
+
|
| 153 |
+
The majority of scripts is licensed under ASL 2.0 (including codes from Diffusers, cloneofsimo's and LoCon), however portions of the project are available under separate license terms:
|
| 154 |
+
|
| 155 |
+
[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
|
| 156 |
+
|
| 157 |
+
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
|
| 158 |
+
|
| 159 |
+
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
## Change History
|
| 163 |
+
|
| 164 |
+
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
|
| 165 |
+
|
| 166 |
+
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
|
| 167 |
+
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
|
| 168 |
+
|
| 169 |
+
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
|
| 170 |
+
|
| 171 |
+
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
| 172 |
+
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
|
| 173 |
+
- If you encounter any issues, please report them.
|
| 174 |
+
|
| 175 |
+
- The dev branch is merged into main. The documentation is delayed, and I apologize for that. I will gradually improve it.
|
| 176 |
+
- The state just before the merge is released as Version 0.8.8, so please use it if you encounter any issues.
|
| 177 |
+
- The following changes are included.
|
| 178 |
+
|
| 179 |
+
#### Changes
|
| 180 |
+
|
| 181 |
+
- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
|
| 182 |
+
- Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
|
| 183 |
+
|
| 184 |
+
- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
|
| 185 |
+
|
| 186 |
+
- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
|
| 187 |
+
- There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).
|
| 188 |
+
|
| 189 |
+
- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
|
| 190 |
+
|
| 191 |
+
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
|
| 192 |
+
|
| 193 |
+
- Improvements in OFT (Orthogonal Finetuning) Implementation
|
| 194 |
+
1. Optimization of Calculation Order:
|
| 195 |
+
- Changed the calculation order in the forward method from (Wx)R to W(xR).
|
| 196 |
+
- This has improved computational efficiency and processing speed.
|
| 197 |
+
2. Correction of Bias Application:
|
| 198 |
+
- In the previous implementation, R was incorrectly applied to the bias.
|
| 199 |
+
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
|
| 200 |
+
3. Efficiency Enhancement in Matrix Operations:
|
| 201 |
+
- Introduced einsum in both the forward and merge_to methods.
|
| 202 |
+
- This has optimized matrix operations, resulting in further speed improvements.
|
| 203 |
+
4. Proper Handling of Data Types:
|
| 204 |
+
- Improved to use torch.float32 during calculations and convert results back to the original data type.
|
| 205 |
+
- This maintains precision while ensuring compatibility with the original model.
|
| 206 |
+
5. Unified Processing for Conv2d and Linear Layers:
|
| 207 |
+
- Implemented a consistent method for applying OFT to both layer types.
|
| 208 |
+
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
|
| 209 |
+
|
| 210 |
+
- Additional Information
|
| 211 |
+
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
|
| 212 |
+
|
| 213 |
+
* Performance Improvement: Training speed has been improved by approximately 30%.
|
| 214 |
+
|
| 215 |
+
* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
|
| 216 |
+
|
| 217 |
+
- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
|
| 218 |
+
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
|
| 219 |
+
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
|
| 220 |
+
|
| 221 |
+
- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
|
| 222 |
+
|
| 223 |
+
- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
|
| 224 |
+
|
| 225 |
+
- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
|
| 226 |
+
|
| 227 |
+
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
|
| 228 |
+
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
|
| 229 |
+
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
|
| 230 |
+
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
|
| 231 |
+
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
|
| 232 |
+
- PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
|
| 233 |
+
- Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts.
|
| 234 |
+
|
| 235 |
+
- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
|
| 236 |
+
- Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
|
| 237 |
+
- Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
|
| 238 |
+
- Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
|
| 239 |
+
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
|
| 240 |
+
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
|
| 241 |
+
|
| 242 |
+
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
|
| 243 |
+
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
|
| 244 |
+
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
|
| 245 |
+
- `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
|
| 246 |
+
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
|
| 247 |
+
- `network_module` `networks.lora` and `networks.dylora` are available.
|
| 248 |
+
|
| 249 |
+
- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru!
|
| 250 |
+
- The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file.
|
| 251 |
+
- See [About masked loss](./docs/masked_loss_README.md) for details.
|
| 252 |
+
|
| 253 |
+
- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
|
| 254 |
+
- Specify the learning rate and dim (rank) for each block.
|
| 255 |
+
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
|
| 256 |
+
|
| 257 |
+
- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath!
|
| 258 |
+
- The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended.
|
| 259 |
+
- When specifying from the command line, use `=` like `--learning_rate=-1e-7`.
|
| 260 |
+
|
| 261 |
+
- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
|
| 262 |
+
- Some settings, such as API keys and directory specifications, are not output due to security issues.
|
| 263 |
+
|
| 264 |
+
- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
|
| 265 |
+
|
| 266 |
+
- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
|
| 267 |
+
- This resolves the issue where the order of data loading from DataSet changes when resuming training.
|
| 268 |
+
- Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
|
| 269 |
+
- If `--resume` is specified, the step saved in the state is used.
|
| 270 |
+
- Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
|
| 271 |
+
|
| 272 |
+
- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
|
| 273 |
+
- It seems that the model file loading is faster in the WSL environment etc.
|
| 274 |
+
- Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
|
| 275 |
+
|
| 276 |
+
- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath!
|
| 277 |
+
|
| 278 |
+
- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291](
|
| 279 |
+
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
|
| 280 |
+
https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
|
| 281 |
+
|
| 282 |
+
- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO!
|
| 283 |
+
|
| 284 |
+
- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th!
|
| 285 |
+
|
| 286 |
+
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
|
| 287 |
+
|
| 288 |
+
- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported.
|
| 289 |
+
|
| 290 |
+
#### 変更点
|
| 291 |
+
|
| 292 |
+
- devブランチがmainにマージされました。ドキュメントの整備が遅れており申し訳ありません。少しずつ整備していきま���。
|
| 293 |
+
- マージ直前の状態が Version 0.8.8 としてリリースされていますので、問題があればそちらをご利用ください。
|
| 294 |
+
- 以下の変更が含まれます。
|
| 295 |
+
|
| 296 |
+
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
|
| 297 |
+
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
|
| 298 |
+
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
|
| 299 |
+
- mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。
|
| 300 |
+
- バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。
|
| 301 |
+
- PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。
|
| 302 |
+
- 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。
|
| 303 |
+
|
| 304 |
+
- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
|
| 305 |
+
- Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。
|
| 306 |
+
- `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。
|
| 307 |
+
- 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。
|
| 308 |
+
- `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
|
| 309 |
+
- 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。
|
| 310 |
+
|
| 311 |
+
- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
|
| 312 |
+
- LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
|
| 313 |
+
- `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
|
| 314 |
+
- `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
|
| 315 |
+
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
|
| 316 |
+
- `network_module` の `networks.lora` および `networks.dylora` で使用可能です。
|
| 317 |
+
|
| 318 |
+
- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。
|
| 319 |
+
- 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。
|
| 320 |
+
- 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。
|
| 321 |
+
|
| 322 |
+
- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
|
| 323 |
+
- ブロックごとに学習率および dim (rank) を指定することができます。
|
| 324 |
+
- 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。
|
| 325 |
+
|
| 326 |
+
- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。
|
| 327 |
+
- 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。
|
| 328 |
+
- コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。
|
| 329 |
+
|
| 330 |
+
- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。
|
| 331 |
+
- API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。
|
| 332 |
+
|
| 333 |
+
- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。
|
| 334 |
+
|
| 335 |
+
- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
|
| 336 |
+
- これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
|
| 337 |
+
- `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
|
| 338 |
+
- `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
|
| 339 |
+
- `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。
|
| 340 |
+
|
| 341 |
+
- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
|
| 342 |
+
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
|
| 343 |
+
- `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
|
| 344 |
+
|
| 345 |
+
- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。
|
| 346 |
+
|
| 347 |
+
- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291](
|
| 348 |
+
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
|
| 349 |
+
https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。
|
| 350 |
+
|
| 351 |
+
- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。
|
| 352 |
+
|
| 353 |
+
- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。
|
| 354 |
+
|
| 355 |
+
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
|
| 356 |
+
|
| 357 |
+
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
### Oct 27, 2024 / 2024-10-27:
|
| 361 |
+
|
| 362 |
+
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
|
| 363 |
+
- This will be included in the next release.
|
| 364 |
+
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。
|
| 365 |
+
- これは次回リリースに含ま���ます。
|
| 366 |
+
|
| 367 |
+
### Oct 26, 2024 / 2024-10-26:
|
| 368 |
+
|
| 369 |
+
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
|
| 370 |
+
- It will be included in the next release.
|
| 371 |
+
|
| 372 |
+
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
|
| 373 |
+
- 以上は次回リリースに含まれます。
|
| 374 |
+
|
| 375 |
+
### Sep 13, 2024 / 2024-09-13:
|
| 376 |
+
|
| 377 |
+
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
|
| 378 |
+
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
|
| 379 |
+
- `sdxl_merge_lora.py` also supports LBW.
|
| 380 |
+
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
|
| 381 |
+
- These will be included in the next release.
|
| 382 |
+
|
| 383 |
+
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
|
| 384 |
+
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
|
| 385 |
+
- `sdxl_merge_lora.py` でも LBW がサポートされました。
|
| 386 |
+
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
|
| 387 |
+
- 以上は次回リリースに含まれます。
|
| 388 |
+
|
| 389 |
+
### Jun 23, 2024 / 2024-06-23:
|
| 390 |
+
|
| 391 |
+
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
|
| 392 |
+
|
| 393 |
+
- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
|
| 394 |
+
|
| 395 |
+
### Apr 7, 2024 / 2024-04-07: v0.8.7
|
| 396 |
+
|
| 397 |
+
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
|
| 398 |
+
|
| 399 |
+
- Scheduled Huber Loss の `huber_schedule` のデフォルト値を `exponential` から、より良い結果が期待できる `snr` に変更しました。
|
| 400 |
+
|
| 401 |
+
### Apr 7, 2024 / 2024-04-07: v0.8.6
|
| 402 |
+
|
| 403 |
+
#### Highlights
|
| 404 |
+
|
| 405 |
+
- The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
| 406 |
+
- Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
|
| 407 |
+
- `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
|
| 408 |
+
- `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
|
| 409 |
+
- Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
|
| 410 |
+
- When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue.
|
| 411 |
+
- A warning is displayed at the start of training if such information is included in the command line.
|
| 412 |
+
- Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
|
| 413 |
+
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
|
| 414 |
+
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
|
| 415 |
+
- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
|
| 416 |
+
|
| 417 |
+
#### Training scripts
|
| 418 |
+
|
| 419 |
+
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
|
| 420 |
+
- Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced.
|
| 421 |
+
- DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
|
| 422 |
+
- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details.
|
| 423 |
+
- Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
|
| 424 |
+
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
|
| 425 |
+
- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
|
| 426 |
+
- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
|
| 427 |
+
|
| 428 |
+
#### Dataset settings
|
| 429 |
+
|
| 430 |
+
- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
|
| 431 |
+
- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
|
| 432 |
+
- Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
|
| 433 |
+
- Some features are added to the dataset subset settings.
|
| 434 |
+
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
|
| 435 |
+
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped.
|
| 436 |
+
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled.
|
| 437 |
+
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
|
| 438 |
+
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
|
| 439 |
+
- See [Dataset config](./docs/config_README-en.md) for details.
|
| 440 |
+
- The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
|
| 441 |
+
|
| 442 |
+
#### Image tagging
|
| 443 |
+
|
| 444 |
+
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
|
| 445 |
+
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
|
| 446 |
+
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
|
| 447 |
+
- Some options are added to `tag_image_by_wd14_tagger.py`.
|
| 448 |
+
- Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
|
| 449 |
+
- Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag`
|
| 450 |
+
- Output character tags first `--character_tags_first`
|
| 451 |
+
- Expand character tags and series `--character_tag_expand`
|
| 452 |
+
- Specify tags to output first `--always_first_tags`
|
| 453 |
+
- Replace tags `--tag_replacement`
|
| 454 |
+
- See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
|
| 455 |
+
- Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
|
| 456 |
+
|
| 457 |
+
#### About Masked loss
|
| 458 |
+
|
| 459 |
+
The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
|
| 460 |
+
|
| 461 |
+
The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
|
| 462 |
+
|
| 463 |
+
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
|
| 464 |
+
|
| 465 |
+
#### About Scheduled Huber Loss
|
| 466 |
+
|
| 467 |
+
Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
|
| 468 |
+
|
| 469 |
+
With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
|
| 470 |
+
|
| 471 |
+
To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.
|
| 472 |
+
|
| 473 |
+
Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.
|
| 474 |
+
|
| 475 |
+
The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset.
|
| 476 |
+
|
| 477 |
+
See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
|
| 478 |
+
|
| 479 |
+
- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
|
| 480 |
+
- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
|
| 481 |
+
- `huber_c`: Specify the Huber's parameter. The default is `0.1`.
|
| 482 |
+
|
| 483 |
+
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
| 484 |
+
|
| 485 |
+
#### 主要な変更点
|
| 486 |
+
|
| 487 |
+
- 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。
|
| 488 |
+
- 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。
|
| 489 |
+
- `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。
|
| 490 |
+
- `bitsandbytes` が公式に Windows をサポートしたため複雑な手順が不要になりました。
|
| 491 |
+
- また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。
|
| 492 |
+
- wandb へのログ出力が有効の場合、コマンドライン全体が公開されます。そのため、コマンドラインに wandb の API キーや HuggingFace のトークンなどが含まれる場合、設定ファイル(`.toml`)への記載をお勧めします。問題提起していただいた bghira 氏に感謝します。
|
| 493 |
+
- このような場合には学習開始時に警告が表示されます。
|
| 494 |
+
- また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。
|
| 495 |
+
- 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。
|
| 496 |
+
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
|
| 497 |
+
- その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。
|
| 498 |
+
|
| 499 |
+
#### 学習スクリプト
|
| 500 |
+
|
| 501 |
+
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
|
| 502 |
+
- `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。
|
| 503 |
+
- DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。
|
| 504 |
+
- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。
|
| 505 |
+
- 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は当該 PR および [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。
|
| 506 |
+
- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
|
| 507 |
+
- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
|
| 508 |
+
- 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。
|
| 509 |
+
|
| 510 |
+
#### データセット設定
|
| 511 |
+
|
| 512 |
+
- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
|
| 513 |
+
- データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。
|
| 514 |
+
- データセットのサブセット設定にいくつかの機能を追加しました。
|
| 515 |
+
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。
|
| 516 |
+
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。
|
| 517 |
+
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
|
| 518 |
+
- 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
|
| 519 |
+
- 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。
|
| 520 |
+
- DreamBooth 方式の DataSet で画像情報(サイズ、キャプション)をキャッシュする機能が追加されました。PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178)、[#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) KohakuBlueleaf 氏に感謝します。詳細は [データセット設定](./docs/config_README-ja.md#dreambooth-方式専用のオプション) をご覧ください。
|
| 521 |
+
- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
|
| 522 |
+
|
| 523 |
+
#### 画像のタグ付け
|
| 524 |
+
|
| 525 |
+
- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
|
| 526 |
+
- Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
|
| 527 |
+
- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデル��ァイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
|
| 528 |
+
- `tag_image_by_wd14_tagger.py` にいくつかのオプションを追加しました。
|
| 529 |
+
- 一部は PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) で追加されました。Disty0 氏に感謝します。
|
| 530 |
+
- レーティングタグを出力する `--use_rating_tags` および `--use_rating_tags_as_last_tag`
|
| 531 |
+
- キャラクタタグを最初に出力する `--character_tags_first`
|
| 532 |
+
- キャラクタタグとシリーズを展開する `--character_tag_expand`
|
| 533 |
+
- 常に最初に出力するタグを指定する `--always_first_tags`
|
| 534 |
+
- タグを置換する `--tag_replacement`
|
| 535 |
+
- 詳細は [タグ付けに関するドキュメント](./docs/wd14_tagger_README-ja.md) をご覧ください。
|
| 536 |
+
- `make_captions.py` で `--beam_search` を指定し `--num_beams` に2以上の値を指定した時のエラーを修正しました。
|
| 537 |
+
|
| 538 |
+
#### マスクロスについて
|
| 539 |
+
|
| 540 |
+
各学習スクリプトでマスクロスをサポートしました。マスクロスを有効にするには `--masked_loss` オプションを指定してください。
|
| 541 |
+
|
| 542 |
+
機能は完全にテストされていないため、不具合があるかもしれません。その場合は Issue を立てていただけると助かります。
|
| 543 |
+
|
| 544 |
+
マスクの指定には ControlNet データセットを使用します。マスク画像は RGB 画像である必要があります。R チャンネルのピクセル値 255 がロス計算対象、0 がロス計算対象外になります。0-255 の値は、0-1 の範囲に変換されます(つまりピクセル値 128 の部分はロスの重みが半分になります)。データセットの詳細は [LLLite ドキュメント](./docs/train_lllite_README-ja.md#データセットの準備) をご覧ください。
|
| 545 |
+
|
| 546 |
+
#### Scheduled Huber Loss について
|
| 547 |
+
|
| 548 |
+
各学習スクリプトに、学習データ中の異常値や外れ値(data corruption)への耐性を高めるための手法、Scheduled Huber Lossが導入されました。
|
| 549 |
+
|
| 550 |
+
従来のMSE(L2)損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。
|
| 551 |
+
|
| 552 |
+
この手法ではHuber損失関数の適用を工夫し、学習の初期段階(ノイズが大きい場合)ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。
|
| 553 |
+
|
| 554 |
+
実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。
|
| 555 |
+
|
| 556 |
+
具体的には、新たに追加された引数loss_type、huber_schedule、huber_cで、損失関数の種類(Huber, smooth L1, MSE)とスケジューリング方法(exponential, constant, SNR)を選択できます。これによりデータセットに応じた最適化が可能になります。
|
| 557 |
+
|
| 558 |
+
詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。
|
| 559 |
+
|
| 560 |
+
- `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。
|
| 561 |
+
- `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `snr` です。
|
| 562 |
+
- `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。
|
| 563 |
+
|
| 564 |
+
PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。
|
| 565 |
+
|
| 566 |
+
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
| 567 |
+
|
| 568 |
+
## Additional Information
|
| 569 |
+
|
| 570 |
+
### Naming of LoRA
|
| 571 |
+
|
| 572 |
+
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
|
| 573 |
+
|
| 574 |
+
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
|
| 575 |
+
|
| 576 |
+
LoRA for Linear layers and Conv2d layers with 1x1 kernel
|
| 577 |
+
|
| 578 |
+
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
|
| 579 |
+
|
| 580 |
+
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
|
| 581 |
+
|
| 582 |
+
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
|
| 583 |
+
<!--
|
| 584 |
+
LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
|
| 585 |
+
|
| 586 |
+
To use LoRA-C3Lier with Web UI, please use our extension.
|
| 587 |
+
-->
|
| 588 |
+
|
| 589 |
+
### Sample image generation during training
|
| 590 |
+
A prompt file might look like this, for example
|
| 591 |
+
|
| 592 |
+
```
|
| 593 |
+
# prompt 1
|
| 594 |
+
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
| 595 |
+
|
| 596 |
+
# prompt 2
|
| 597 |
+
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
| 598 |
+
```
|
| 599 |
+
|
| 600 |
+
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
| 601 |
+
|
| 602 |
+
* `--n` Negative prompt up to the next option.
|
| 603 |
+
* `--w` Specifies the width of the generated image.
|
| 604 |
+
* `--h` Specifies the height of the generated image.
|
| 605 |
+
* `--d` Specifies the seed of the generated image.
|
| 606 |
+
* `--l` Specifies the CFG scale of the generated image.
|
| 607 |
+
* `--s` Specifies the number of steps in the generation.
|
| 608 |
+
|
| 609 |
+
The prompt weighting such as `( )` and `[ ]` are working.
|
XTI_hijack.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from library.device_utils import init_ipex
|
| 3 |
+
init_ipex()
|
| 4 |
+
|
| 5 |
+
from typing import Union, List, Optional, Dict, Any, Tuple
|
| 6 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
| 7 |
+
|
| 8 |
+
from library.original_unet import SampleOutput
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def unet_forward_XTI(
|
| 12 |
+
self,
|
| 13 |
+
sample: torch.FloatTensor,
|
| 14 |
+
timestep: Union[torch.Tensor, float, int],
|
| 15 |
+
encoder_hidden_states: torch.Tensor,
|
| 16 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 17 |
+
return_dict: bool = True,
|
| 18 |
+
) -> Union[Dict, Tuple]:
|
| 19 |
+
r"""
|
| 20 |
+
Args:
|
| 21 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
| 22 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
| 23 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
| 24 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 25 |
+
Whether or not to return a dict instead of a plain tuple.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
`SampleOutput` or `tuple`:
|
| 29 |
+
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
| 30 |
+
"""
|
| 31 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 32 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
| 33 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 34 |
+
# on the fly if necessary.
|
| 35 |
+
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
| 36 |
+
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
| 37 |
+
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
| 38 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 39 |
+
|
| 40 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 41 |
+
# 64で割り切れないときはupsamplerにサイズを伝える
|
| 42 |
+
forward_upsample_size = False
|
| 43 |
+
upsample_size = None
|
| 44 |
+
|
| 45 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 46 |
+
# logger.info("Forward upsample size to force interpolation output size.")
|
| 47 |
+
forward_upsample_size = True
|
| 48 |
+
|
| 49 |
+
# 1. time
|
| 50 |
+
timesteps = timestep
|
| 51 |
+
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
| 52 |
+
|
| 53 |
+
t_emb = self.time_proj(timesteps)
|
| 54 |
+
|
| 55 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 56 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 57 |
+
# there might be better ways to encapsulate this.
|
| 58 |
+
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
| 59 |
+
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
| 60 |
+
# time_projでキャストしておけばいいんじゃね?
|
| 61 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
| 62 |
+
emb = self.time_embedding(t_emb)
|
| 63 |
+
|
| 64 |
+
# 2. pre-process
|
| 65 |
+
sample = self.conv_in(sample)
|
| 66 |
+
|
| 67 |
+
# 3. down
|
| 68 |
+
down_block_res_samples = (sample,)
|
| 69 |
+
down_i = 0
|
| 70 |
+
for downsample_block in self.down_blocks:
|
| 71 |
+
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
| 72 |
+
# まあこちらのほうがわかりやすいかもしれない
|
| 73 |
+
if downsample_block.has_cross_attention:
|
| 74 |
+
sample, res_samples = downsample_block(
|
| 75 |
+
hidden_states=sample,
|
| 76 |
+
temb=emb,
|
| 77 |
+
encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
|
| 78 |
+
)
|
| 79 |
+
down_i += 2
|
| 80 |
+
else:
|
| 81 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 82 |
+
|
| 83 |
+
down_block_res_samples += res_samples
|
| 84 |
+
|
| 85 |
+
# 4. mid
|
| 86 |
+
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
|
| 87 |
+
|
| 88 |
+
# 5. up
|
| 89 |
+
up_i = 7
|
| 90 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 91 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 92 |
+
|
| 93 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 94 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
| 95 |
+
|
| 96 |
+
# if we have not reached the final block and need to forward the upsample size, we do it here
|
| 97 |
+
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
| 98 |
+
if not is_final_block and forward_upsample_size:
|
| 99 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 100 |
+
|
| 101 |
+
if upsample_block.has_cross_attention:
|
| 102 |
+
sample = upsample_block(
|
| 103 |
+
hidden_states=sample,
|
| 104 |
+
temb=emb,
|
| 105 |
+
res_hidden_states_tuple=res_samples,
|
| 106 |
+
encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
|
| 107 |
+
upsample_size=upsample_size,
|
| 108 |
+
)
|
| 109 |
+
up_i += 3
|
| 110 |
+
else:
|
| 111 |
+
sample = upsample_block(
|
| 112 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# 6. post-process
|
| 116 |
+
sample = self.conv_norm_out(sample)
|
| 117 |
+
sample = self.conv_act(sample)
|
| 118 |
+
sample = self.conv_out(sample)
|
| 119 |
+
|
| 120 |
+
if not return_dict:
|
| 121 |
+
return (sample,)
|
| 122 |
+
|
| 123 |
+
return SampleOutput(sample=sample)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def downblock_forward_XTI(
|
| 127 |
+
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
|
| 128 |
+
):
|
| 129 |
+
output_states = ()
|
| 130 |
+
i = 0
|
| 131 |
+
|
| 132 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
| 133 |
+
if self.training and self.gradient_checkpointing:
|
| 134 |
+
|
| 135 |
+
def create_custom_forward(module, return_dict=None):
|
| 136 |
+
def custom_forward(*inputs):
|
| 137 |
+
if return_dict is not None:
|
| 138 |
+
return module(*inputs, return_dict=return_dict)
|
| 139 |
+
else:
|
| 140 |
+
return module(*inputs)
|
| 141 |
+
|
| 142 |
+
return custom_forward
|
| 143 |
+
|
| 144 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 145 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 146 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
|
| 147 |
+
)[0]
|
| 148 |
+
else:
|
| 149 |
+
hidden_states = resnet(hidden_states, temb)
|
| 150 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
| 151 |
+
|
| 152 |
+
output_states += (hidden_states,)
|
| 153 |
+
i += 1
|
| 154 |
+
|
| 155 |
+
if self.downsamplers is not None:
|
| 156 |
+
for downsampler in self.downsamplers:
|
| 157 |
+
hidden_states = downsampler(hidden_states)
|
| 158 |
+
|
| 159 |
+
output_states += (hidden_states,)
|
| 160 |
+
|
| 161 |
+
return hidden_states, output_states
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def upblock_forward_XTI(
|
| 165 |
+
self,
|
| 166 |
+
hidden_states,
|
| 167 |
+
res_hidden_states_tuple,
|
| 168 |
+
temb=None,
|
| 169 |
+
encoder_hidden_states=None,
|
| 170 |
+
upsample_size=None,
|
| 171 |
+
):
|
| 172 |
+
i = 0
|
| 173 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
| 174 |
+
# pop res hidden states
|
| 175 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
| 176 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
| 177 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
| 178 |
+
|
| 179 |
+
if self.training and self.gradient_checkpointing:
|
| 180 |
+
|
| 181 |
+
def create_custom_forward(module, return_dict=None):
|
| 182 |
+
def custom_forward(*inputs):
|
| 183 |
+
if return_dict is not None:
|
| 184 |
+
return module(*inputs, return_dict=return_dict)
|
| 185 |
+
else:
|
| 186 |
+
return module(*inputs)
|
| 187 |
+
|
| 188 |
+
return custom_forward
|
| 189 |
+
|
| 190 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
| 191 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 192 |
+
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
|
| 193 |
+
)[0]
|
| 194 |
+
else:
|
| 195 |
+
hidden_states = resnet(hidden_states, temb)
|
| 196 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
|
| 197 |
+
|
| 198 |
+
i += 1
|
| 199 |
+
|
| 200 |
+
if self.upsamplers is not None:
|
| 201 |
+
for upsampler in self.upsamplers:
|
| 202 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
| 203 |
+
|
| 204 |
+
return hidden_states
|
_typos.toml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Files for typos
|
| 2 |
+
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
| 3 |
+
|
| 4 |
+
[default.extend-identifiers]
|
| 5 |
+
ddPn08="ddPn08"
|
| 6 |
+
|
| 7 |
+
[default.extend-words]
|
| 8 |
+
NIN="NIN"
|
| 9 |
+
parms="parms"
|
| 10 |
+
nin="nin"
|
| 11 |
+
extention="extention" # Intentionally left
|
| 12 |
+
nd="nd"
|
| 13 |
+
shs="shs"
|
| 14 |
+
sts="sts"
|
| 15 |
+
scs="scs"
|
| 16 |
+
cpc="cpc"
|
| 17 |
+
coc="coc"
|
| 18 |
+
cic="cic"
|
| 19 |
+
msm="msm"
|
| 20 |
+
usu="usu"
|
| 21 |
+
ici="ici"
|
| 22 |
+
lvl="lvl"
|
| 23 |
+
dii="dii"
|
| 24 |
+
muk="muk"
|
| 25 |
+
ori="ori"
|
| 26 |
+
hru="hru"
|
| 27 |
+
rik="rik"
|
| 28 |
+
koo="koo"
|
| 29 |
+
yos="yos"
|
| 30 |
+
wn="wn"
|
| 31 |
+
hime="hime"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
[files]
|
| 35 |
+
extend-exclude = ["_typos.toml", "venv"]
|
app.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LoRA Trainer Funcional para Hugging Face
|
| 4 |
+
Baseado no kohya-ss sd-scripts
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import gradio as gr
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import json
|
| 11 |
+
import subprocess
|
| 12 |
+
import shutil
|
| 13 |
+
import zipfile
|
| 14 |
+
import tempfile
|
| 15 |
+
import toml
|
| 16 |
+
import logging
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional, Tuple, List, Dict, Any
|
| 19 |
+
import time
|
| 20 |
+
import threading
|
| 21 |
+
import queue
|
| 22 |
+
|
| 23 |
+
# Adicionar o diretório sd-scripts ao path
|
| 24 |
+
sys.path.insert(0, str(Path(__file__).parent / "sd-scripts"))
|
| 25 |
+
|
| 26 |
+
# Configurar logging
|
| 27 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
class LoRATrainerHF:
|
| 31 |
+
def __init__(self):
|
| 32 |
+
self.base_dir = Path("/tmp/lora_training")
|
| 33 |
+
self.base_dir.mkdir(exist_ok=True)
|
| 34 |
+
|
| 35 |
+
self.models_dir = self.base_dir / "models"
|
| 36 |
+
self.models_dir.mkdir(exist_ok=True)
|
| 37 |
+
|
| 38 |
+
self.projects_dir = self.base_dir / "projects"
|
| 39 |
+
self.projects_dir.mkdir(exist_ok=True)
|
| 40 |
+
|
| 41 |
+
self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"
|
| 42 |
+
|
| 43 |
+
# URLs dos modelos
|
| 44 |
+
self.model_urls = {
|
| 45 |
+
"Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
|
| 46 |
+
"AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
|
| 47 |
+
"Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
|
| 48 |
+
"Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
self.training_process = None
|
| 52 |
+
self.training_output_queue = queue.Queue()
|
| 53 |
+
|
| 54 |
+
def install_dependencies(self) -> str:
|
| 55 |
+
"""Instala as dependências necessárias"""
|
| 56 |
+
try:
|
| 57 |
+
logger.info("Instalando dependências...")
|
| 58 |
+
|
| 59 |
+
# Lista de pacotes necessários
|
| 60 |
+
packages = [
|
| 61 |
+
"torch>=2.0.0",
|
| 62 |
+
"torchvision>=0.15.0",
|
| 63 |
+
"diffusers>=0.21.0",
|
| 64 |
+
"transformers>=4.25.0",
|
| 65 |
+
"accelerate>=0.20.0",
|
| 66 |
+
"safetensors>=0.3.0",
|
| 67 |
+
"huggingface-hub>=0.16.0",
|
| 68 |
+
"xformers>=0.0.20",
|
| 69 |
+
"bitsandbytes>=0.41.0",
|
| 70 |
+
"opencv-python>=4.7.0",
|
| 71 |
+
"Pillow>=9.0.0",
|
| 72 |
+
"numpy>=1.21.0",
|
| 73 |
+
"tqdm>=4.64.0",
|
| 74 |
+
"toml>=0.10.0",
|
| 75 |
+
"tensorboard>=2.13.0",
|
| 76 |
+
"wandb>=0.15.0",
|
| 77 |
+
"scipy>=1.9.0",
|
| 78 |
+
"matplotlib>=3.5.0",
|
| 79 |
+
"datasets>=2.14.0",
|
| 80 |
+
"peft>=0.5.0",
|
| 81 |
+
"omegaconf>=2.3.0"
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
# Instalar pacotes
|
| 85 |
+
for package in packages:
|
| 86 |
+
try:
|
| 87 |
+
subprocess.run([
|
| 88 |
+
sys.executable, "-m", "pip", "install", package, "--quiet"
|
| 89 |
+
], check=True, capture_output=True, text=True)
|
| 90 |
+
logger.info(f"✓ {package} instalado")
|
| 91 |
+
except subprocess.CalledProcessError as e:
|
| 92 |
+
logger.warning(f"⚠ Erro ao instalar {package}: {e}")
|
| 93 |
+
|
| 94 |
+
return "✅ Dependências instaladas com sucesso!"
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
logger.error(f"Erro ao instalar dependências: {e}")
|
| 98 |
+
return f"❌ Erro ao instalar dependências: {e}"
|
| 99 |
+
|
| 100 |
+
def download_model(self, model_choice: str, custom_url: str = "") -> str:
|
| 101 |
+
"""Download do modelo base"""
|
| 102 |
+
try:
|
| 103 |
+
if custom_url.strip():
|
| 104 |
+
model_url = custom_url.strip()
|
| 105 |
+
model_name = model_url.split("/")[-1]
|
| 106 |
+
else:
|
| 107 |
+
if model_choice not in self.model_urls:
|
| 108 |
+
return f"❌ Modelo '{model_choice}' não encontrado"
|
| 109 |
+
model_url = self.model_urls[model_choice]
|
| 110 |
+
model_name = model_url.split("/")[-1]
|
| 111 |
+
|
| 112 |
+
model_path = self.models_dir / model_name
|
| 113 |
+
|
| 114 |
+
if model_path.exists():
|
| 115 |
+
return f"✅ Modelo já existe: {model_name}"
|
| 116 |
+
|
| 117 |
+
logger.info(f"Baixando modelo: {model_url}")
|
| 118 |
+
|
| 119 |
+
# Download usando wget
|
| 120 |
+
result = subprocess.run([
|
| 121 |
+
"wget", "-O", str(model_path), model_url, "--progress=bar:force"
|
| 122 |
+
], capture_output=True, text=True)
|
| 123 |
+
|
| 124 |
+
if result.returncode == 0:
|
| 125 |
+
return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"
|
| 126 |
+
else:
|
| 127 |
+
return f"❌ Erro no download: {result.stderr}"
|
| 128 |
+
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.error(f"Erro ao baixar modelo: {e}")
|
| 131 |
+
return f"❌ Erro ao baixar modelo: {e}"
|
| 132 |
+
|
| 133 |
+
def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
|
| 134 |
+
"""Processa o dataset enviado"""
|
| 135 |
+
try:
|
| 136 |
+
if not dataset_zip:
|
| 137 |
+
return "❌ Nenhum dataset foi enviado", ""
|
| 138 |
+
|
| 139 |
+
if not project_name.strip():
|
| 140 |
+
return "❌ Nome do projeto é obrigatório", ""
|
| 141 |
+
|
| 142 |
+
project_name = project_name.strip().replace(" ", "_")
|
| 143 |
+
project_dir = self.projects_dir / project_name
|
| 144 |
+
project_dir.mkdir(exist_ok=True)
|
| 145 |
+
|
| 146 |
+
dataset_dir = project_dir / "dataset"
|
| 147 |
+
if dataset_dir.exists():
|
| 148 |
+
shutil.rmtree(dataset_dir)
|
| 149 |
+
dataset_dir.mkdir()
|
| 150 |
+
|
| 151 |
+
# Extrair ZIP
|
| 152 |
+
with zipfile.ZipFile(dataset_zip.name, 'r') as zip_ref:
|
| 153 |
+
zip_ref.extractall(dataset_dir)
|
| 154 |
+
|
| 155 |
+
# Analisar dataset
|
| 156 |
+
image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}
|
| 157 |
+
images = []
|
| 158 |
+
captions = []
|
| 159 |
+
|
| 160 |
+
for file_path in dataset_dir.rglob("*"):
|
| 161 |
+
if file_path.suffix.lower() in image_extensions:
|
| 162 |
+
images.append(file_path)
|
| 163 |
+
|
| 164 |
+
# Procurar caption
|
| 165 |
+
caption_path = file_path.with_suffix('.txt')
|
| 166 |
+
if caption_path.exists():
|
| 167 |
+
captions.append(caption_path)
|
| 168 |
+
|
| 169 |
+
info = f"✅ Dataset processado!\n"
|
| 170 |
+
info += f"📁 Projeto: {project_name}\n"
|
| 171 |
+
info += f"🖼️ Imagens: {len(images)}\n"
|
| 172 |
+
info += f"📝 Captions: {len(captions)}\n"
|
| 173 |
+
info += f"📂 Diretório: {dataset_dir}"
|
| 174 |
+
|
| 175 |
+
return info, str(dataset_dir)
|
| 176 |
+
|
| 177 |
+
except Exception as e:
|
| 178 |
+
logger.error(f"Erro ao processar dataset: {e}")
|
| 179 |
+
return f"❌ Erro ao processar dataset: {e}", ""
|
| 180 |
+
|
| 181 |
+
def create_training_config(self,
|
| 182 |
+
project_name: str,
|
| 183 |
+
dataset_dir: str,
|
| 184 |
+
model_choice: str,
|
| 185 |
+
custom_model_url: str,
|
| 186 |
+
resolution: int,
|
| 187 |
+
batch_size: int,
|
| 188 |
+
epochs: int,
|
| 189 |
+
learning_rate: float,
|
| 190 |
+
text_encoder_lr: float,
|
| 191 |
+
network_dim: int,
|
| 192 |
+
network_alpha: int,
|
| 193 |
+
lora_type: str,
|
| 194 |
+
optimizer: str,
|
| 195 |
+
scheduler: str,
|
| 196 |
+
flip_aug: bool,
|
| 197 |
+
shuffle_caption: bool,
|
| 198 |
+
keep_tokens: int,
|
| 199 |
+
clip_skip: int,
|
| 200 |
+
mixed_precision: str,
|
| 201 |
+
save_every_n_epochs: int,
|
| 202 |
+
max_train_steps: int) -> str:
|
| 203 |
+
"""Cria configuração de treinamento"""
|
| 204 |
+
try:
|
| 205 |
+
if not project_name.strip():
|
| 206 |
+
return "❌ Nome do projeto é obrigatório"
|
| 207 |
+
|
| 208 |
+
project_name = project_name.strip().replace(" ", "_")
|
| 209 |
+
project_dir = self.projects_dir / project_name
|
| 210 |
+
project_dir.mkdir(exist_ok=True)
|
| 211 |
+
|
| 212 |
+
output_dir = project_dir / "output"
|
| 213 |
+
output_dir.mkdir(exist_ok=True)
|
| 214 |
+
|
| 215 |
+
log_dir = project_dir / "logs"
|
| 216 |
+
log_dir.mkdir(exist_ok=True)
|
| 217 |
+
|
| 218 |
+
# Determinar modelo
|
| 219 |
+
if custom_model_url.strip():
|
| 220 |
+
model_name = custom_model_url.strip().split("/")[-1]
|
| 221 |
+
else:
|
| 222 |
+
model_name = self.model_urls[model_choice].split("/")[-1]
|
| 223 |
+
|
| 224 |
+
model_path = self.models_dir / model_name
|
| 225 |
+
|
| 226 |
+
if not model_path.exists():
|
| 227 |
+
return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro."
|
| 228 |
+
|
| 229 |
+
# Configuração do dataset
|
| 230 |
+
dataset_config = {
|
| 231 |
+
"general": {
|
| 232 |
+
"shuffle_caption": shuffle_caption,
|
| 233 |
+
"caption_extension": ".txt",
|
| 234 |
+
"keep_tokens": keep_tokens,
|
| 235 |
+
"flip_aug": flip_aug,
|
| 236 |
+
"color_aug": False,
|
| 237 |
+
"face_crop_aug_range": None,
|
| 238 |
+
"random_crop": False,
|
| 239 |
+
"debug_dataset": False
|
| 240 |
+
},
|
| 241 |
+
"datasets": [{
|
| 242 |
+
"resolution": resolution,
|
| 243 |
+
"batch_size": batch_size,
|
| 244 |
+
"subsets": [{
|
| 245 |
+
"image_dir": str(dataset_dir),
|
| 246 |
+
"num_repeats": 1
|
| 247 |
+
}]
|
| 248 |
+
}]
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
# Configuração de treinamento
|
| 252 |
+
training_config = {
|
| 253 |
+
"model_arguments": {
|
| 254 |
+
"pretrained_model_name_or_path": str(model_path),
|
| 255 |
+
"v2": False,
|
| 256 |
+
"v_parameterization": False,
|
| 257 |
+
"clip_skip": clip_skip
|
| 258 |
+
},
|
| 259 |
+
"dataset_arguments": {
|
| 260 |
+
"dataset_config": str(project_dir / "dataset_config.toml")
|
| 261 |
+
},
|
| 262 |
+
"training_arguments": {
|
| 263 |
+
"output_dir": str(output_dir),
|
| 264 |
+
"output_name": project_name,
|
| 265 |
+
"save_precision": "fp16",
|
| 266 |
+
"save_every_n_epochs": save_every_n_epochs,
|
| 267 |
+
"max_train_epochs": epochs if max_train_steps == 0 else None,
|
| 268 |
+
"max_train_steps": max_train_steps if max_train_steps > 0 else None,
|
| 269 |
+
"train_batch_size": batch_size,
|
| 270 |
+
"gradient_accumulation_steps": 1,
|
| 271 |
+
"learning_rate": learning_rate,
|
| 272 |
+
"text_encoder_lr": text_encoder_lr,
|
| 273 |
+
"lr_scheduler": scheduler,
|
| 274 |
+
"lr_warmup_steps": 0,
|
| 275 |
+
"optimizer_type": optimizer,
|
| 276 |
+
"mixed_precision": mixed_precision,
|
| 277 |
+
"save_model_as": "safetensors",
|
| 278 |
+
"seed": 42,
|
| 279 |
+
"max_data_loader_n_workers": 2,
|
| 280 |
+
"persistent_data_loader_workers": True,
|
| 281 |
+
"gradient_checkpointing": True,
|
| 282 |
+
"xformers": True,
|
| 283 |
+
"lowram": True,
|
| 284 |
+
"cache_latents": True,
|
| 285 |
+
"cache_latents_to_disk": True,
|
| 286 |
+
"logging_dir": str(log_dir),
|
| 287 |
+
"log_with": "tensorboard"
|
| 288 |
+
},
|
| 289 |
+
"network_arguments": {
|
| 290 |
+
"network_module": "networks.lora" if lora_type == "LoRA" else "networks.dylora",
|
| 291 |
+
"network_dim": network_dim,
|
| 292 |
+
"network_alpha": network_alpha,
|
| 293 |
+
"network_train_unet_only": False,
|
| 294 |
+
"network_train_text_encoder_only": False
|
| 295 |
+
}
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
# Adicionar argumentos específicos para LoCon
|
| 299 |
+
if lora_type == "LoCon":
|
| 300 |
+
training_config["network_arguments"]["network_module"] = "networks.lora"
|
| 301 |
+
training_config["network_arguments"]["conv_dim"] = max(1, network_dim // 2)
|
| 302 |
+
training_config["network_arguments"]["conv_alpha"] = max(1, network_alpha // 2)
|
| 303 |
+
|
| 304 |
+
# Salvar configurações
|
| 305 |
+
dataset_config_path = project_dir / "dataset_config.toml"
|
| 306 |
+
training_config_path = project_dir / "training_config.toml"
|
| 307 |
+
|
| 308 |
+
with open(dataset_config_path, 'w') as f:
|
| 309 |
+
toml.dump(dataset_config, f)
|
| 310 |
+
|
| 311 |
+
with open(training_config_path, 'w') as f:
|
| 312 |
+
toml.dump(training_config, f)
|
| 313 |
+
|
| 314 |
+
return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}"
|
| 315 |
+
|
| 316 |
+
except Exception as e:
|
| 317 |
+
logger.error(f"Erro ao criar configuração: {e}")
|
| 318 |
+
return f"❌ Erro ao criar configuração: {e}"
|
| 319 |
+
|
| 320 |
+
def start_training(self, project_name: str) -> str:
|
| 321 |
+
"""Inicia o treinamento"""
|
| 322 |
+
try:
|
| 323 |
+
if not project_name.strip():
|
| 324 |
+
return "❌ Nome do projeto é obrigatório"
|
| 325 |
+
|
| 326 |
+
project_name = project_name.strip().replace(" ", "_")
|
| 327 |
+
project_dir = self.projects_dir / project_name
|
| 328 |
+
|
| 329 |
+
training_config_path = project_dir / "training_config.toml"
|
| 330 |
+
if not training_config_path.exists():
|
| 331 |
+
return "❌ Configuração não encontrada. Crie a configuração primeiro."
|
| 332 |
+
|
| 333 |
+
# Script de treinamento
|
| 334 |
+
train_script = self.sd_scripts_dir / "train_network.py"
|
| 335 |
+
if not train_script.exists():
|
| 336 |
+
return "❌ Script de treinamento não encontrado"
|
| 337 |
+
|
| 338 |
+
# Comando de treinamento
|
| 339 |
+
cmd = [
|
| 340 |
+
sys.executable,
|
| 341 |
+
str(train_script),
|
| 342 |
+
"--config_file", str(training_config_path)
|
| 343 |
+
]
|
| 344 |
+
|
| 345 |
+
logger.info(f"Iniciando treinamento: {' '.join(cmd)}")
|
| 346 |
+
|
| 347 |
+
# Executar em thread separada
|
| 348 |
+
def run_training():
|
| 349 |
+
try:
|
| 350 |
+
process = subprocess.Popen(
|
| 351 |
+
cmd,
|
| 352 |
+
stdout=subprocess.PIPE,
|
| 353 |
+
stderr=subprocess.STDOUT,
|
| 354 |
+
text=True,
|
| 355 |
+
bufsize=1,
|
| 356 |
+
universal_newlines=True,
|
| 357 |
+
cwd=str(self.sd_scripts_dir)
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
self.training_process = process
|
| 361 |
+
|
| 362 |
+
for line in process.stdout:
|
| 363 |
+
self.training_output_queue.put(line.strip())
|
| 364 |
+
logger.info(line.strip())
|
| 365 |
+
|
| 366 |
+
process.wait()
|
| 367 |
+
|
| 368 |
+
if process.returncode == 0:
|
| 369 |
+
self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!")
|
| 370 |
+
else:
|
| 371 |
+
self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})")
|
| 372 |
+
|
| 373 |
+
except Exception as e:
|
| 374 |
+
self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}")
|
| 375 |
+
finally:
|
| 376 |
+
self.training_process = None
|
| 377 |
+
|
| 378 |
+
# Iniciar thread
|
| 379 |
+
training_thread = threading.Thread(target=run_training)
|
| 380 |
+
training_thread.daemon = True
|
| 381 |
+
training_thread.start()
|
| 382 |
+
|
| 383 |
+
return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo."
|
| 384 |
+
|
| 385 |
+
except Exception as e:
|
| 386 |
+
logger.error(f"Erro ao iniciar treinamento: {e}")
|
| 387 |
+
return f"❌ Erro ao iniciar treinamento: {e}"
|
| 388 |
+
|
| 389 |
+
def get_training_output(self) -> str:
|
| 390 |
+
"""Obtém output do treinamento"""
|
| 391 |
+
output_lines = []
|
| 392 |
+
try:
|
| 393 |
+
while not self.training_output_queue.empty():
|
| 394 |
+
line = self.training_output_queue.get_nowait()
|
| 395 |
+
output_lines.append(line)
|
| 396 |
+
except queue.Empty:
|
| 397 |
+
pass
|
| 398 |
+
|
| 399 |
+
if output_lines:
|
| 400 |
+
return "\n".join(output_lines)
|
| 401 |
+
elif self.training_process and self.training_process.poll() is None:
|
| 402 |
+
return "🔄 Treinamento em andamento..."
|
| 403 |
+
else:
|
| 404 |
+
return "⏸️ Nenhum treinamento ativo"
|
| 405 |
+
|
| 406 |
+
def stop_training(self) -> str:
|
| 407 |
+
"""Para o treinamento"""
|
| 408 |
+
try:
|
| 409 |
+
if self.training_process and self.training_process.poll() is None:
|
| 410 |
+
self.training_process.terminate()
|
| 411 |
+
self.training_process.wait(timeout=10)
|
| 412 |
+
return "⏹️ Treinamento interrompido"
|
| 413 |
+
else:
|
| 414 |
+
return "ℹ️ Nenhum treinamento ativo para parar"
|
| 415 |
+
except Exception as e:
|
| 416 |
+
return f"❌ Erro ao parar treinamento: {e}"
|
| 417 |
+
|
| 418 |
+
def list_output_files(self, project_name: str) -> List[str]:
|
| 419 |
+
"""Lista arquivos de saída"""
|
| 420 |
+
try:
|
| 421 |
+
if not project_name.strip():
|
| 422 |
+
return []
|
| 423 |
+
|
| 424 |
+
project_name = project_name.strip().replace(" ", "_")
|
| 425 |
+
project_dir = self.projects_dir / project_name
|
| 426 |
+
output_dir = project_dir / "output"
|
| 427 |
+
|
| 428 |
+
if not output_dir.exists():
|
| 429 |
+
return []
|
| 430 |
+
|
| 431 |
+
files = []
|
| 432 |
+
for file_path in output_dir.rglob("*.safetensors"):
|
| 433 |
+
size_mb = file_path.stat().st_size // (1024 * 1024)
|
| 434 |
+
files.append(f"{file_path.name} ({size_mb} MB)")
|
| 435 |
+
|
| 436 |
+
return sorted(files, reverse=True) # Mais recentes primeiro
|
| 437 |
+
|
| 438 |
+
except Exception as e:
|
| 439 |
+
logger.error(f"Erro ao listar arquivos: {e}")
|
| 440 |
+
return []
|
| 441 |
+
|
| 442 |
+
# Instância global
|
| 443 |
+
trainer = LoRATrainerHF()
|
| 444 |
+
|
| 445 |
+
def create_interface():
|
| 446 |
+
"""Cria a interface Gradio"""
|
| 447 |
+
|
| 448 |
+
with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface:
|
| 449 |
+
|
| 450 |
+
gr.Markdown("""
|
| 451 |
+
# 🎨 LoRA Trainer Funcional para Hugging Face
|
| 452 |
+
|
| 453 |
+
**Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!**
|
| 454 |
+
|
| 455 |
+
Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA.
|
| 456 |
+
""")
|
| 457 |
+
|
| 458 |
+
# Estado para armazenar informações
|
| 459 |
+
dataset_dir_state = gr.State("")
|
| 460 |
+
|
| 461 |
+
with gr.Tab("🔧 Instalação"):
|
| 462 |
+
gr.Markdown("### Primeiro, instale as dependências necessárias:")
|
| 463 |
+
install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg")
|
| 464 |
+
install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False)
|
| 465 |
+
|
| 466 |
+
install_btn.click(
|
| 467 |
+
fn=trainer.install_dependencies,
|
| 468 |
+
outputs=install_status
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
with gr.Tab("📁 Configuração do Projeto"):
|
| 472 |
+
with gr.Row():
|
| 473 |
+
project_name = gr.Textbox(
|
| 474 |
+
label="Nome do Projeto",
|
| 475 |
+
placeholder="meu_lora_anime",
|
| 476 |
+
info="Nome único para seu projeto (sem espaços especiais)"
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
gr.Markdown("### 📥 Download do Modelo Base")
|
| 480 |
+
with gr.Row():
|
| 481 |
+
model_choice = gr.Dropdown(
|
| 482 |
+
choices=list(trainer.model_urls.keys()),
|
| 483 |
+
label="Modelo Base Pré-definido",
|
| 484 |
+
value="Anime (animefull-final-pruned)",
|
| 485 |
+
info="Escolha um modelo base ou use URL personalizada"
|
| 486 |
+
)
|
| 487 |
+
custom_model_url = gr.Textbox(
|
| 488 |
+
label="URL Personalizada (opcional)",
|
| 489 |
+
placeholder="https://huggingface.co/...",
|
| 490 |
+
info="URL direta para download de modelo personalizado"
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
download_btn = gr.Button("📥 Baixar Modelo", variant="primary")
|
| 494 |
+
download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False)
|
| 495 |
+
|
| 496 |
+
gr.Markdown("### 📊 Upload do Dataset")
|
| 497 |
+
gr.Markdown("""
|
| 498 |
+
**Formato do Dataset:**
|
| 499 |
+
- Crie um arquivo ZIP contendo suas imagens
|
| 500 |
+
- Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições
|
| 501 |
+
- Exemplo: `imagem1.jpg` + `imagem1.txt`
|
| 502 |
+
""")
|
| 503 |
+
|
| 504 |
+
dataset_upload = gr.File(
|
| 505 |
+
label="Upload do Dataset (ZIP)",
|
| 506 |
+
file_types=[".zip"]
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
process_btn = gr.Button("📊 Processar Dataset", variant="primary")
|
| 510 |
+
dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False)
|
| 511 |
+
|
| 512 |
+
with gr.Tab("⚙️ Parâmetros de Treinamento"):
|
| 513 |
+
with gr.Row():
|
| 514 |
+
with gr.Column():
|
| 515 |
+
gr.Markdown("#### 🖼️ Configurações de Imagem")
|
| 516 |
+
resolution = gr.Slider(
|
| 517 |
+
minimum=512, maximum=1024, step=64, value=512,
|
| 518 |
+
label="Resolução",
|
| 519 |
+
info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)"
|
| 520 |
+
)
|
| 521 |
+
batch_size = gr.Slider(
|
| 522 |
+
minimum=1, maximum=8, step=1, value=1,
|
| 523 |
+
label="Batch Size",
|
| 524 |
+
info="Imagens por lote (aumente se tiver GPU potente)"
|
| 525 |
+
)
|
| 526 |
+
flip_aug = gr.Checkbox(
|
| 527 |
+
label="Flip Augmentation",
|
| 528 |
+
info="Espelhar imagens para aumentar dataset"
|
| 529 |
+
)
|
| 530 |
+
shuffle_caption = gr.Checkbox(
|
| 531 |
+
value=True,
|
| 532 |
+
label="Shuffle Caption",
|
| 533 |
+
info="Embaralhar ordem das tags"
|
| 534 |
+
)
|
| 535 |
+
keep_tokens = gr.Slider(
|
| 536 |
+
minimum=0, maximum=5, step=1, value=1,
|
| 537 |
+
label="Keep Tokens",
|
| 538 |
+
info="Número de tokens iniciais que não serão embaralhados"
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
with gr.Column():
|
| 542 |
+
gr.Markdown("#### 🎯 Configurações de Treinamento")
|
| 543 |
+
epochs = gr.Slider(
|
| 544 |
+
minimum=1, maximum=100, step=1, value=10,
|
| 545 |
+
label="Épocas",
|
| 546 |
+
info="Número de épocas de treinamento"
|
| 547 |
+
)
|
| 548 |
+
max_train_steps = gr.Number(
|
| 549 |
+
value=0,
|
| 550 |
+
label="Max Train Steps (0 = usar épocas)",
|
| 551 |
+
info="Número máximo de steps (deixe 0 para usar épocas)"
|
| 552 |
+
)
|
| 553 |
+
save_every_n_epochs = gr.Slider(
|
| 554 |
+
minimum=1, maximum=10, step=1, value=1,
|
| 555 |
+
label="Salvar a cada N épocas",
|
| 556 |
+
info="Frequência de salvamento dos checkpoints"
|
| 557 |
+
)
|
| 558 |
+
mixed_precision = gr.Dropdown(
|
| 559 |
+
choices=["fp16", "bf16", "no"],
|
| 560 |
+
value="fp16",
|
| 561 |
+
label="Mixed Precision",
|
| 562 |
+
info="fp16 = mais rápido, bf16 = mais estável"
|
| 563 |
+
)
|
| 564 |
+
clip_skip = gr.Slider(
|
| 565 |
+
minimum=1, maximum=12, step=1, value=2,
|
| 566 |
+
label="CLIP Skip",
|
| 567 |
+
info="Camadas CLIP a pular (2 para anime, 1 para realista)"
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
with gr.Row():
|
| 571 |
+
with gr.Column():
|
| 572 |
+
gr.Markdown("#### 📚 Learning Rate")
|
| 573 |
+
learning_rate = gr.Number(
|
| 574 |
+
value=1e-4,
|
| 575 |
+
label="Learning Rate (UNet)",
|
| 576 |
+
info="Taxa de aprendizado principal"
|
| 577 |
+
)
|
| 578 |
+
text_encoder_lr = gr.Number(
|
| 579 |
+
value=5e-5,
|
| 580 |
+
label="Learning Rate (Text Encoder)",
|
| 581 |
+
info="Taxa de aprendizado do text encoder"
|
| 582 |
+
)
|
| 583 |
+
scheduler = gr.Dropdown(
|
| 584 |
+
choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
|
| 585 |
+
value="cosine_with_restarts",
|
| 586 |
+
label="LR Scheduler",
|
| 587 |
+
info="Algoritmo de ajuste da learning rate"
|
| 588 |
+
)
|
| 589 |
+
optimizer = gr.Dropdown(
|
| 590 |
+
choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
|
| 591 |
+
value="AdamW8bit",
|
| 592 |
+
label="Otimizador",
|
| 593 |
+
info="AdamW8bit = menos memória"
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
with gr.Column():
|
| 597 |
+
gr.Markdown("#### 🧠 Arquitetura LoRA")
|
| 598 |
+
lora_type = gr.Radio(
|
| 599 |
+
choices=["LoRA", "LoCon"],
|
| 600 |
+
value="LoRA",
|
| 601 |
+
label="Tipo de LoRA",
|
| 602 |
+
info="LoRA = geral, LoCon = estilos artísticos"
|
| 603 |
+
)
|
| 604 |
+
network_dim = gr.Slider(
|
| 605 |
+
minimum=4, maximum=128, step=4, value=32,
|
| 606 |
+
label="Network Dimension",
|
| 607 |
+
info="Dimensão da rede (maior = mais detalhes, mais memória)"
|
| 608 |
+
)
|
| 609 |
+
network_alpha = gr.Slider(
|
| 610 |
+
minimum=1, maximum=128, step=1, value=16,
|
| 611 |
+
label="Network Alpha",
|
| 612 |
+
info="Controla a força do LoRA (geralmente dim/2)"
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
with gr.Tab("🚀 Treinamento"):
|
| 616 |
+
create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg")
|
| 617 |
+
config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False)
|
| 618 |
+
|
| 619 |
+
with gr.Row():
|
| 620 |
+
start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg")
|
| 621 |
+
stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop")
|
| 622 |
+
|
| 623 |
+
training_output = gr.Textbox(
|
| 624 |
+
label="Output do Treinamento",
|
| 625 |
+
lines=15,
|
| 626 |
+
interactive=False,
|
| 627 |
+
info="Acompanhe o progresso do treinamento em tempo real"
|
| 628 |
+
)
|
| 629 |
+
|
| 630 |
+
# Auto-refresh do output
|
| 631 |
+
def update_output():
|
| 632 |
+
return trainer.get_training_output()
|
| 633 |
+
|
| 634 |
+
with gr.Tab("📥 Download dos Resultados"):
|
| 635 |
+
refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary")
|
| 636 |
+
|
| 637 |
+
output_files = gr.Dropdown(
|
| 638 |
+
label="Arquivos LoRA Gerados",
|
| 639 |
+
choices=[],
|
| 640 |
+
info="Selecione um arquivo para download"
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento")
|
| 644 |
+
|
| 645 |
+
# Event handlers
|
| 646 |
+
download_btn.click(
|
| 647 |
+
fn=trainer.download_model,
|
| 648 |
+
inputs=[model_choice, custom_model_url],
|
| 649 |
+
outputs=download_status
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
process_btn.click(
|
| 653 |
+
fn=trainer.process_dataset,
|
| 654 |
+
inputs=[dataset_upload, project_name],
|
| 655 |
+
outputs=[dataset_status, dataset_dir_state]
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
create_config_btn.click(
|
| 659 |
+
fn=trainer.create_training_config,
|
| 660 |
+
inputs=[
|
| 661 |
+
project_name, dataset_dir_state, model_choice, custom_model_url,
|
| 662 |
+
resolution, batch_size, epochs, learning_rate, text_encoder_lr,
|
| 663 |
+
network_dim, network_alpha, lora_type, optimizer, scheduler,
|
| 664 |
+
flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
|
| 665 |
+
save_every_n_epochs, max_train_steps
|
| 666 |
+
],
|
| 667 |
+
outputs=config_status
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
start_training_btn.click(
|
| 671 |
+
fn=trainer.start_training,
|
| 672 |
+
inputs=project_name,
|
| 673 |
+
outputs=training_output
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
stop_training_btn.click(
|
| 677 |
+
fn=trainer.stop_training,
|
| 678 |
+
outputs=training_output
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
refresh_files_btn.click(
|
| 682 |
+
fn=trainer.list_output_files,
|
| 683 |
+
inputs=project_name,
|
| 684 |
+
outputs=output_files
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
return interface
|
| 688 |
+
|
| 689 |
+
if __name__ == "__main__":
|
| 690 |
+
print("🚀 Iniciando LoRA Trainer Funcional...")
|
| 691 |
+
interface = create_interface()
|
| 692 |
+
interface.launch(
|
| 693 |
+
server_name="0.0.0.0",
|
| 694 |
+
server_port=7860,
|
| 695 |
+
share=False,
|
| 696 |
+
show_error=True
|
| 697 |
+
)
|
| 698 |
+
|
fine_tune.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# training with captions
|
| 2 |
+
# XXX dropped option: hypernetwork training
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import math
|
| 6 |
+
import os
|
| 7 |
+
from multiprocessing import Value
|
| 8 |
+
import toml
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from library import deepspeed_utils
|
| 14 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 15 |
+
|
| 16 |
+
init_ipex()
|
| 17 |
+
|
| 18 |
+
from accelerate.utils import set_seed
|
| 19 |
+
from diffusers import DDPMScheduler
|
| 20 |
+
|
| 21 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 22 |
+
|
| 23 |
+
setup_logging()
|
| 24 |
+
import logging
|
| 25 |
+
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
import library.train_util as train_util
|
| 29 |
+
import library.config_util as config_util
|
| 30 |
+
from library.config_util import (
|
| 31 |
+
ConfigSanitizer,
|
| 32 |
+
BlueprintGenerator,
|
| 33 |
+
)
|
| 34 |
+
import library.custom_train_functions as custom_train_functions
|
| 35 |
+
from library.custom_train_functions import (
|
| 36 |
+
apply_snr_weight,
|
| 37 |
+
get_weighted_text_embeddings,
|
| 38 |
+
prepare_scheduler_for_custom_training,
|
| 39 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 40 |
+
apply_debiased_estimation,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def train(args):
|
| 45 |
+
train_util.verify_training_args(args)
|
| 46 |
+
train_util.prepare_dataset_args(args, True)
|
| 47 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
| 48 |
+
setup_logging(args, reset=True)
|
| 49 |
+
|
| 50 |
+
cache_latents = args.cache_latents
|
| 51 |
+
|
| 52 |
+
if args.seed is not None:
|
| 53 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
| 54 |
+
|
| 55 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 56 |
+
|
| 57 |
+
# データセットを準備する
|
| 58 |
+
if args.dataset_class is None:
|
| 59 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
| 60 |
+
if args.dataset_config is not None:
|
| 61 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 62 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 63 |
+
ignored = ["train_data_dir", "in_json"]
|
| 64 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 65 |
+
logger.warning(
|
| 66 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 67 |
+
", ".join(ignored)
|
| 68 |
+
)
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
user_config = {
|
| 72 |
+
"datasets": [
|
| 73 |
+
{
|
| 74 |
+
"subsets": [
|
| 75 |
+
{
|
| 76 |
+
"image_dir": args.train_data_dir,
|
| 77 |
+
"metadata_file": args.in_json,
|
| 78 |
+
}
|
| 79 |
+
]
|
| 80 |
+
}
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
| 85 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 86 |
+
else:
|
| 87 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
| 88 |
+
|
| 89 |
+
current_epoch = Value("i", 0)
|
| 90 |
+
current_step = Value("i", 0)
|
| 91 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 92 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 93 |
+
|
| 94 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
| 95 |
+
|
| 96 |
+
if args.debug_dataset:
|
| 97 |
+
train_util.debug_dataset(train_dataset_group)
|
| 98 |
+
return
|
| 99 |
+
if len(train_dataset_group) == 0:
|
| 100 |
+
logger.error(
|
| 101 |
+
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
| 102 |
+
)
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
if cache_latents:
|
| 106 |
+
assert (
|
| 107 |
+
train_dataset_group.is_latent_cacheable()
|
| 108 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 109 |
+
|
| 110 |
+
# acceleratorを準備する
|
| 111 |
+
logger.info("prepare accelerator")
|
| 112 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 113 |
+
|
| 114 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 115 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 116 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 117 |
+
|
| 118 |
+
# モデルを読み込む
|
| 119 |
+
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 120 |
+
|
| 121 |
+
# verify load/save model formats
|
| 122 |
+
if load_stable_diffusion_format:
|
| 123 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
| 124 |
+
src_diffusers_model_path = None
|
| 125 |
+
else:
|
| 126 |
+
src_stable_diffusion_ckpt = None
|
| 127 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
| 128 |
+
|
| 129 |
+
if args.save_model_as is None:
|
| 130 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
| 131 |
+
use_safetensors = args.use_safetensors
|
| 132 |
+
else:
|
| 133 |
+
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
| 134 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
| 135 |
+
|
| 136 |
+
# Diffusers版のxformers使用フラグを設定する関数
|
| 137 |
+
def set_diffusers_xformers_flag(model, valid):
|
| 138 |
+
# model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
|
| 139 |
+
# pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
|
| 140 |
+
# U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
|
| 141 |
+
# 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
|
| 142 |
+
|
| 143 |
+
# Recursively walk through all the children.
|
| 144 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
| 145 |
+
# gets the message
|
| 146 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
| 147 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
| 148 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
| 149 |
+
|
| 150 |
+
for child in module.children():
|
| 151 |
+
fn_recursive_set_mem_eff(child)
|
| 152 |
+
|
| 153 |
+
fn_recursive_set_mem_eff(model)
|
| 154 |
+
|
| 155 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 156 |
+
if args.diffusers_xformers:
|
| 157 |
+
accelerator.print("Use xformers by Diffusers")
|
| 158 |
+
set_diffusers_xformers_flag(unet, True)
|
| 159 |
+
else:
|
| 160 |
+
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
|
| 161 |
+
accelerator.print("Disable Diffusers' xformers")
|
| 162 |
+
set_diffusers_xformers_flag(unet, False)
|
| 163 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 164 |
+
|
| 165 |
+
# 学習を準備する
|
| 166 |
+
if cache_latents:
|
| 167 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 168 |
+
vae.requires_grad_(False)
|
| 169 |
+
vae.eval()
|
| 170 |
+
with torch.no_grad():
|
| 171 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
| 172 |
+
vae.to("cpu")
|
| 173 |
+
clean_memory_on_device(accelerator.device)
|
| 174 |
+
|
| 175 |
+
accelerator.wait_for_everyone()
|
| 176 |
+
|
| 177 |
+
# 学習を準備する:モデルを適切な状態にする
|
| 178 |
+
training_models = []
|
| 179 |
+
if args.gradient_checkpointing:
|
| 180 |
+
unet.enable_gradient_checkpointing()
|
| 181 |
+
training_models.append(unet)
|
| 182 |
+
|
| 183 |
+
if args.train_text_encoder:
|
| 184 |
+
accelerator.print("enable text encoder training")
|
| 185 |
+
if args.gradient_checkpointing:
|
| 186 |
+
text_encoder.gradient_checkpointing_enable()
|
| 187 |
+
training_models.append(text_encoder)
|
| 188 |
+
else:
|
| 189 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 190 |
+
text_encoder.requires_grad_(False) # text encoderは学習しない
|
| 191 |
+
if args.gradient_checkpointing:
|
| 192 |
+
text_encoder.gradient_checkpointing_enable()
|
| 193 |
+
text_encoder.train() # required for gradient_checkpointing
|
| 194 |
+
else:
|
| 195 |
+
text_encoder.eval()
|
| 196 |
+
|
| 197 |
+
if not cache_latents:
|
| 198 |
+
vae.requires_grad_(False)
|
| 199 |
+
vae.eval()
|
| 200 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 201 |
+
|
| 202 |
+
for m in training_models:
|
| 203 |
+
m.requires_grad_(True)
|
| 204 |
+
|
| 205 |
+
trainable_params = []
|
| 206 |
+
if args.learning_rate_te is None or not args.train_text_encoder:
|
| 207 |
+
for m in training_models:
|
| 208 |
+
trainable_params.extend(m.parameters())
|
| 209 |
+
else:
|
| 210 |
+
trainable_params = [
|
| 211 |
+
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
| 212 |
+
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
| 213 |
+
]
|
| 214 |
+
|
| 215 |
+
# 学習に必要なクラスを準備する
|
| 216 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 217 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
| 218 |
+
|
| 219 |
+
# dataloaderを準備する
|
| 220 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 221 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 222 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 223 |
+
train_dataset_group,
|
| 224 |
+
batch_size=1,
|
| 225 |
+
shuffle=True,
|
| 226 |
+
collate_fn=collator,
|
| 227 |
+
num_workers=n_workers,
|
| 228 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# 学習ステップ数を計算する
|
| 232 |
+
if args.max_train_epochs is not None:
|
| 233 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 234 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 235 |
+
)
|
| 236 |
+
accelerator.print(
|
| 237 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# データセット側にも学習ステップを送信
|
| 241 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 242 |
+
|
| 243 |
+
# lr schedulerを用意する
|
| 244 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 245 |
+
|
| 246 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
| 247 |
+
if args.full_fp16:
|
| 248 |
+
assert (
|
| 249 |
+
args.mixed_precision == "fp16"
|
| 250 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 251 |
+
accelerator.print("enable full fp16 training.")
|
| 252 |
+
unet.to(weight_dtype)
|
| 253 |
+
text_encoder.to(weight_dtype)
|
| 254 |
+
|
| 255 |
+
if args.deepspeed:
|
| 256 |
+
if args.train_text_encoder:
|
| 257 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
| 258 |
+
else:
|
| 259 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
| 260 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 261 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
| 262 |
+
)
|
| 263 |
+
training_models = [ds_model]
|
| 264 |
+
else:
|
| 265 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 266 |
+
if args.train_text_encoder:
|
| 267 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 268 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 269 |
+
)
|
| 270 |
+
else:
|
| 271 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
| 272 |
+
|
| 273 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 274 |
+
if args.full_fp16:
|
| 275 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 276 |
+
|
| 277 |
+
# resumeする
|
| 278 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 279 |
+
|
| 280 |
+
# epoch数を計算する
|
| 281 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 282 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 283 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 284 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 285 |
+
|
| 286 |
+
# 学習する
|
| 287 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 288 |
+
accelerator.print("running training / 学習開始")
|
| 289 |
+
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
| 290 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 291 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 292 |
+
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
| 293 |
+
accelerator.print(
|
| 294 |
+
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
| 295 |
+
)
|
| 296 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 297 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 298 |
+
|
| 299 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 300 |
+
global_step = 0
|
| 301 |
+
|
| 302 |
+
noise_scheduler = DDPMScheduler(
|
| 303 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 304 |
+
)
|
| 305 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 306 |
+
if args.zero_terminal_snr:
|
| 307 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 308 |
+
|
| 309 |
+
if accelerator.is_main_process:
|
| 310 |
+
init_kwargs = {}
|
| 311 |
+
if args.wandb_run_name:
|
| 312 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 313 |
+
if args.log_tracker_config is not None:
|
| 314 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 315 |
+
accelerator.init_trackers(
|
| 316 |
+
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
| 317 |
+
config=train_util.get_sanitized_config_or_none(args),
|
| 318 |
+
init_kwargs=init_kwargs,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# For --sample_at_first
|
| 322 |
+
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 323 |
+
|
| 324 |
+
loss_recorder = train_util.LossRecorder()
|
| 325 |
+
for epoch in range(num_train_epochs):
|
| 326 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 327 |
+
current_epoch.value = epoch + 1
|
| 328 |
+
|
| 329 |
+
for m in training_models:
|
| 330 |
+
m.train()
|
| 331 |
+
|
| 332 |
+
for step, batch in enumerate(train_dataloader):
|
| 333 |
+
current_step.value = global_step
|
| 334 |
+
with accelerator.accumulate(*training_models):
|
| 335 |
+
with torch.no_grad():
|
| 336 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 337 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 338 |
+
else:
|
| 339 |
+
# latentに変換
|
| 340 |
+
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
|
| 341 |
+
latents = latents * 0.18215
|
| 342 |
+
b_size = latents.shape[0]
|
| 343 |
+
|
| 344 |
+
with torch.set_grad_enabled(args.train_text_encoder):
|
| 345 |
+
# Get the text embedding for conditioning
|
| 346 |
+
if args.weighted_captions:
|
| 347 |
+
encoder_hidden_states = get_weighted_text_embeddings(
|
| 348 |
+
tokenizer,
|
| 349 |
+
text_encoder,
|
| 350 |
+
batch["captions"],
|
| 351 |
+
accelerator.device,
|
| 352 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
| 353 |
+
clip_skip=args.clip_skip,
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
| 357 |
+
encoder_hidden_states = train_util.get_hidden_states(
|
| 358 |
+
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 362 |
+
# with noise offset and/or multires noise if specified
|
| 363 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
| 364 |
+
args, noise_scheduler, latents
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# Predict the noise residual
|
| 368 |
+
with accelerator.autocast():
|
| 369 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 370 |
+
|
| 371 |
+
if args.v_parameterization:
|
| 372 |
+
# v-parameterization training
|
| 373 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 374 |
+
else:
|
| 375 |
+
target = noise
|
| 376 |
+
|
| 377 |
+
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
| 378 |
+
# do not mean over batch dimension for snr weight or scale v-pred loss
|
| 379 |
+
loss = train_util.conditional_loss(
|
| 380 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
| 381 |
+
)
|
| 382 |
+
loss = loss.mean([1, 2, 3])
|
| 383 |
+
|
| 384 |
+
if args.min_snr_gamma:
|
| 385 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 386 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 387 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 388 |
+
if args.debiased_estimation_loss:
|
| 389 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 390 |
+
|
| 391 |
+
loss = loss.mean() # mean over batch dimension
|
| 392 |
+
else:
|
| 393 |
+
loss = train_util.conditional_loss(
|
| 394 |
+
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
accelerator.backward(loss)
|
| 398 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 399 |
+
params_to_clip = []
|
| 400 |
+
for m in training_models:
|
| 401 |
+
params_to_clip.extend(m.parameters())
|
| 402 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 403 |
+
|
| 404 |
+
optimizer.step()
|
| 405 |
+
lr_scheduler.step()
|
| 406 |
+
optimizer.zero_grad(set_to_none=True)
|
| 407 |
+
|
| 408 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 409 |
+
if accelerator.sync_gradients:
|
| 410 |
+
progress_bar.update(1)
|
| 411 |
+
global_step += 1
|
| 412 |
+
|
| 413 |
+
train_util.sample_images(
|
| 414 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# 指定ステップごとにモデルを保存
|
| 418 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 419 |
+
accelerator.wait_for_everyone()
|
| 420 |
+
if accelerator.is_main_process:
|
| 421 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 422 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
| 423 |
+
args,
|
| 424 |
+
False,
|
| 425 |
+
accelerator,
|
| 426 |
+
src_path,
|
| 427 |
+
save_stable_diffusion_format,
|
| 428 |
+
use_safetensors,
|
| 429 |
+
save_dtype,
|
| 430 |
+
epoch,
|
| 431 |
+
num_train_epochs,
|
| 432 |
+
global_step,
|
| 433 |
+
accelerator.unwrap_model(text_encoder),
|
| 434 |
+
accelerator.unwrap_model(unet),
|
| 435 |
+
vae,
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
| 439 |
+
if args.logging_dir is not None:
|
| 440 |
+
logs = {"loss": current_loss}
|
| 441 |
+
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
| 442 |
+
accelerator.log(logs, step=global_step)
|
| 443 |
+
|
| 444 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 445 |
+
avr_loss: float = loss_recorder.moving_average
|
| 446 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 447 |
+
progress_bar.set_postfix(**logs)
|
| 448 |
+
|
| 449 |
+
if global_step >= args.max_train_steps:
|
| 450 |
+
break
|
| 451 |
+
|
| 452 |
+
if args.logging_dir is not None:
|
| 453 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 454 |
+
accelerator.log(logs, step=epoch + 1)
|
| 455 |
+
|
| 456 |
+
accelerator.wait_for_everyone()
|
| 457 |
+
|
| 458 |
+
if args.save_every_n_epochs is not None:
|
| 459 |
+
if accelerator.is_main_process:
|
| 460 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 461 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
| 462 |
+
args,
|
| 463 |
+
True,
|
| 464 |
+
accelerator,
|
| 465 |
+
src_path,
|
| 466 |
+
save_stable_diffusion_format,
|
| 467 |
+
use_safetensors,
|
| 468 |
+
save_dtype,
|
| 469 |
+
epoch,
|
| 470 |
+
num_train_epochs,
|
| 471 |
+
global_step,
|
| 472 |
+
accelerator.unwrap_model(text_encoder),
|
| 473 |
+
accelerator.unwrap_model(unet),
|
| 474 |
+
vae,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 478 |
+
|
| 479 |
+
is_main_process = accelerator.is_main_process
|
| 480 |
+
if is_main_process:
|
| 481 |
+
unet = accelerator.unwrap_model(unet)
|
| 482 |
+
text_encoder = accelerator.unwrap_model(text_encoder)
|
| 483 |
+
|
| 484 |
+
accelerator.end_training()
|
| 485 |
+
|
| 486 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 487 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 488 |
+
|
| 489 |
+
del accelerator # この後メモリを使うのでこれは消す
|
| 490 |
+
|
| 491 |
+
if is_main_process:
|
| 492 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 493 |
+
train_util.save_sd_model_on_train_end(
|
| 494 |
+
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
| 495 |
+
)
|
| 496 |
+
logger.info("model saved.")
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 500 |
+
parser = argparse.ArgumentParser()
|
| 501 |
+
|
| 502 |
+
add_logging_arguments(parser)
|
| 503 |
+
train_util.add_sd_models_arguments(parser)
|
| 504 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
| 505 |
+
train_util.add_training_arguments(parser, False)
|
| 506 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 507 |
+
train_util.add_sd_saving_arguments(parser)
|
| 508 |
+
train_util.add_optimizer_arguments(parser)
|
| 509 |
+
config_util.add_config_arguments(parser)
|
| 510 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 511 |
+
|
| 512 |
+
parser.add_argument(
|
| 513 |
+
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
| 514 |
+
)
|
| 515 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
| 516 |
+
parser.add_argument(
|
| 517 |
+
"--learning_rate_te",
|
| 518 |
+
type=float,
|
| 519 |
+
default=None,
|
| 520 |
+
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
| 521 |
+
)
|
| 522 |
+
parser.add_argument(
|
| 523 |
+
"--no_half_vae",
|
| 524 |
+
action="store_true",
|
| 525 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
return parser
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
if __name__ == "__main__":
|
| 532 |
+
parser = setup_parser()
|
| 533 |
+
|
| 534 |
+
args = parser.parse_args()
|
| 535 |
+
train_util.verify_command_line_training_args(args)
|
| 536 |
+
args = train_util.read_config_from_file(args, parser)
|
| 537 |
+
|
| 538 |
+
train(args)
|
gen_img.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
gen_img_diffusers.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==0.30.0
|
| 2 |
+
transformers==4.44.0
|
| 3 |
+
diffusers[torch]==0.25.0
|
| 4 |
+
ftfy==6.1.1
|
| 5 |
+
# albumentations==1.3.0
|
| 6 |
+
opencv-python==4.8.1.78
|
| 7 |
+
einops==0.7.0
|
| 8 |
+
pytorch-lightning==1.9.0
|
| 9 |
+
bitsandbytes==0.44.0
|
| 10 |
+
prodigyopt==1.0
|
| 11 |
+
lion-pytorch==0.0.6
|
| 12 |
+
tensorboard
|
| 13 |
+
safetensors==0.4.2
|
| 14 |
+
# gradio==3.16.2
|
| 15 |
+
altair==4.2.2
|
| 16 |
+
easygui==0.98.3
|
| 17 |
+
toml==0.10.2
|
| 18 |
+
voluptuous==0.13.1
|
| 19 |
+
huggingface-hub==0.24.5
|
| 20 |
+
# for Image utils
|
| 21 |
+
imagesize==1.4.1
|
| 22 |
+
# for BLIP captioning
|
| 23 |
+
# requests==2.28.2
|
| 24 |
+
# timm==0.6.12
|
| 25 |
+
# fairscale==0.4.13
|
| 26 |
+
# for WD14 captioning (tensorflow)
|
| 27 |
+
# tensorflow==2.10.1
|
| 28 |
+
# for WD14 captioning (onnx)
|
| 29 |
+
# onnx==1.15.0
|
| 30 |
+
# onnxruntime-gpu==1.17.1
|
| 31 |
+
# onnxruntime==1.17.1
|
| 32 |
+
# for cuda 12.1(default 11.8)
|
| 33 |
+
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
|
| 34 |
+
|
| 35 |
+
# this is for onnx:
|
| 36 |
+
# protobuf==3.20.3
|
| 37 |
+
# open clip for SDXL
|
| 38 |
+
# open-clip-torch==2.20.0
|
| 39 |
+
# For logging
|
| 40 |
+
rich==13.7.0
|
| 41 |
+
# for kohya_ss library
|
| 42 |
+
-e .
|
sdxl_gen_img.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
sdxl_minimal_inference.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う
|
| 2 |
+
# Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import datetime
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
from einops import repeat
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from library.device_utils import init_ipex, get_preferred_device
|
| 14 |
+
|
| 15 |
+
init_ipex()
|
| 16 |
+
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import CLIPTokenizer
|
| 19 |
+
from diffusers import EulerDiscreteScheduler
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
# import open_clip
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
|
| 25 |
+
from library import model_util, sdxl_model_util
|
| 26 |
+
import networks.lora as lora
|
| 27 |
+
from library.utils import setup_logging
|
| 28 |
+
|
| 29 |
+
setup_logging()
|
| 30 |
+
import logging
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
| 35 |
+
# scheduler: The settings around here seem to be the same as SD1/2
|
| 36 |
+
SCHEDULER_LINEAR_START = 0.00085
|
| 37 |
+
SCHEDULER_LINEAR_END = 0.0120
|
| 38 |
+
SCHEDULER_TIMESTEPS = 1000
|
| 39 |
+
SCHEDLER_SCHEDULE = "scaled_linear"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Time EmbeddingはDiffusersからのコピー
|
| 43 |
+
# Time Embedding is copied from Diffusers
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
| 47 |
+
"""
|
| 48 |
+
Create sinusoidal timestep embeddings.
|
| 49 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 50 |
+
These may be fractional.
|
| 51 |
+
:param dim: the dimension of the output.
|
| 52 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 53 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
| 54 |
+
"""
|
| 55 |
+
if not repeat_only:
|
| 56 |
+
half = dim // 2
|
| 57 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
| 58 |
+
device=timesteps.device
|
| 59 |
+
)
|
| 60 |
+
args = timesteps[:, None].float() * freqs[None]
|
| 61 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 62 |
+
if dim % 2:
|
| 63 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 64 |
+
else:
|
| 65 |
+
embedding = repeat(timesteps, "b -> b d", d=dim)
|
| 66 |
+
return embedding
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def get_timestep_embedding(x, outdim):
|
| 70 |
+
assert len(x.shape) == 2
|
| 71 |
+
b, dims = x.shape[0], x.shape[1]
|
| 72 |
+
# x = rearrange(x, "b d -> (b d)")
|
| 73 |
+
x = torch.flatten(x)
|
| 74 |
+
emb = timestep_embedding(x, outdim)
|
| 75 |
+
# emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
|
| 76 |
+
emb = torch.reshape(emb, (b, dims * outdim))
|
| 77 |
+
return emb
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
# 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
|
| 82 |
+
|
| 83 |
+
# SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
|
| 84 |
+
target_height = 1024
|
| 85 |
+
target_width = 1024
|
| 86 |
+
original_height = target_height
|
| 87 |
+
original_width = target_width
|
| 88 |
+
crop_top = 0
|
| 89 |
+
crop_left = 0
|
| 90 |
+
|
| 91 |
+
steps = 50
|
| 92 |
+
guidance_scale = 7
|
| 93 |
+
seed = None # 1
|
| 94 |
+
|
| 95 |
+
DEVICE = get_preferred_device()
|
| 96 |
+
DTYPE = torch.float16 # bfloat16 may work
|
| 97 |
+
|
| 98 |
+
parser = argparse.ArgumentParser()
|
| 99 |
+
parser.add_argument("--ckpt_path", type=str, required=True)
|
| 100 |
+
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
| 101 |
+
parser.add_argument("--prompt2", type=str, default=None)
|
| 102 |
+
parser.add_argument("--negative_prompt", type=str, default="")
|
| 103 |
+
parser.add_argument("--output_dir", type=str, default=".")
|
| 104 |
+
parser.add_argument(
|
| 105 |
+
"--lora_weights",
|
| 106 |
+
type=str,
|
| 107 |
+
nargs="*",
|
| 108 |
+
default=[],
|
| 109 |
+
help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
|
| 110 |
+
)
|
| 111 |
+
parser.add_argument("--interactive", action="store_true")
|
| 112 |
+
args = parser.parse_args()
|
| 113 |
+
|
| 114 |
+
if args.prompt2 is None:
|
| 115 |
+
args.prompt2 = args.prompt
|
| 116 |
+
|
| 117 |
+
# HuggingFaceのmodel id
|
| 118 |
+
text_encoder_1_name = "openai/clip-vit-large-patch14"
|
| 119 |
+
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
| 120 |
+
|
| 121 |
+
# checkpointを読み込む。モデル変換についてはそちらの関数を参照
|
| 122 |
+
# Load checkpoint. For model conversion, see this function
|
| 123 |
+
|
| 124 |
+
# 本体RAMが少ない場合はGPUにロードするといいかも
|
| 125 |
+
# If the main RAM is small, it may be better to load it on the GPU
|
| 126 |
+
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
| 127 |
+
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
|
| 131 |
+
# In SDXL, Text Encoder 1 is also using HuggingFace's
|
| 132 |
+
|
| 133 |
+
# Text Encoder 2はSDXL本体ではopen_clipを使っている
|
| 134 |
+
# それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う
|
| 135 |
+
# 重みの変換コードはSD2とほぼ同じ
|
| 136 |
+
# In SDXL, Text Encoder 2 is using open_clip
|
| 137 |
+
# It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's
|
| 138 |
+
# The weight conversion code is almost the same as SD2
|
| 139 |
+
|
| 140 |
+
# VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う
|
| 141 |
+
# fp16でNaNが出やすいようだ
|
| 142 |
+
# The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different.
|
| 143 |
+
# NaN seems to be more likely to occur in fp16
|
| 144 |
+
|
| 145 |
+
unet.to(DEVICE, dtype=DTYPE)
|
| 146 |
+
unet.eval()
|
| 147 |
+
|
| 148 |
+
vae_dtype = DTYPE
|
| 149 |
+
if DTYPE == torch.float16:
|
| 150 |
+
logger.info("use float32 for vae")
|
| 151 |
+
vae_dtype = torch.float32
|
| 152 |
+
vae.to(DEVICE, dtype=vae_dtype)
|
| 153 |
+
vae.eval()
|
| 154 |
+
|
| 155 |
+
text_model1.to(DEVICE, dtype=DTYPE)
|
| 156 |
+
text_model1.eval()
|
| 157 |
+
text_model2.to(DEVICE, dtype=DTYPE)
|
| 158 |
+
text_model2.eval()
|
| 159 |
+
|
| 160 |
+
unet.set_use_memory_efficient_attention(True, False)
|
| 161 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
| 162 |
+
vae.set_use_memory_efficient_attention_xformers(True)
|
| 163 |
+
|
| 164 |
+
# Tokenizers
|
| 165 |
+
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
|
| 166 |
+
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
|
| 167 |
+
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
|
| 168 |
+
|
| 169 |
+
# LoRA
|
| 170 |
+
for weights_file in args.lora_weights:
|
| 171 |
+
if ";" in weights_file:
|
| 172 |
+
weights_file, multiplier = weights_file.split(";")
|
| 173 |
+
multiplier = float(multiplier)
|
| 174 |
+
else:
|
| 175 |
+
multiplier = 1.0
|
| 176 |
+
|
| 177 |
+
lora_model, weights_sd = lora.create_network_from_weights(
|
| 178 |
+
multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
|
| 179 |
+
)
|
| 180 |
+
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
|
| 181 |
+
|
| 182 |
+
# scheduler
|
| 183 |
+
scheduler = EulerDiscreteScheduler(
|
| 184 |
+
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
| 185 |
+
beta_start=SCHEDULER_LINEAR_START,
|
| 186 |
+
beta_end=SCHEDULER_LINEAR_END,
|
| 187 |
+
beta_schedule=SCHEDLER_SCHEDULE,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def generate_image(prompt, prompt2, negative_prompt, seed=None):
|
| 191 |
+
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
|
| 192 |
+
# prepare embedding
|
| 193 |
+
with torch.no_grad():
|
| 194 |
+
# vector
|
| 195 |
+
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
| 196 |
+
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
| 197 |
+
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
| 198 |
+
# logger.info("emb1", emb1.shape)
|
| 199 |
+
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
| 200 |
+
uc_vector = c_vector.clone().to(
|
| 201 |
+
DEVICE, dtype=DTYPE
|
| 202 |
+
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
| 203 |
+
|
| 204 |
+
# crossattn
|
| 205 |
+
|
| 206 |
+
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
|
| 207 |
+
def call_text_encoder(text, text2):
|
| 208 |
+
# text encoder 1
|
| 209 |
+
batch_encoding = tokenizer1(
|
| 210 |
+
text,
|
| 211 |
+
truncation=True,
|
| 212 |
+
return_length=True,
|
| 213 |
+
return_overflowing_tokens=False,
|
| 214 |
+
padding="max_length",
|
| 215 |
+
return_tensors="pt",
|
| 216 |
+
)
|
| 217 |
+
tokens = batch_encoding["input_ids"].to(DEVICE)
|
| 218 |
+
|
| 219 |
+
with torch.no_grad():
|
| 220 |
+
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
|
| 221 |
+
text_embedding1 = enc_out["hidden_states"][11]
|
| 222 |
+
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
|
| 223 |
+
|
| 224 |
+
# text encoder 2
|
| 225 |
+
# tokens = tokenizer2(text2).to(DEVICE)
|
| 226 |
+
tokens = tokenizer2(
|
| 227 |
+
text,
|
| 228 |
+
truncation=True,
|
| 229 |
+
return_length=True,
|
| 230 |
+
return_overflowing_tokens=False,
|
| 231 |
+
padding="max_length",
|
| 232 |
+
return_tensors="pt",
|
| 233 |
+
)
|
| 234 |
+
tokens = batch_encoding["input_ids"].to(DEVICE)
|
| 235 |
+
|
| 236 |
+
with torch.no_grad():
|
| 237 |
+
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
| 238 |
+
text_embedding2_penu = enc_out["hidden_states"][-2]
|
| 239 |
+
# logger.info("hidden_states2", text_embedding2_penu.shape)
|
| 240 |
+
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
| 241 |
+
|
| 242 |
+
# 連結して終了 concat and finish
|
| 243 |
+
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
| 244 |
+
return text_embedding, text_embedding2_pool
|
| 245 |
+
|
| 246 |
+
# cond
|
| 247 |
+
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
| 248 |
+
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
| 249 |
+
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
| 250 |
+
|
| 251 |
+
# uncond
|
| 252 |
+
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
|
| 253 |
+
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
|
| 254 |
+
|
| 255 |
+
text_embeddings = torch.cat([uc_ctx, c_ctx])
|
| 256 |
+
vector_embeddings = torch.cat([uc_vector, c_vector])
|
| 257 |
+
|
| 258 |
+
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
|
| 259 |
+
|
| 260 |
+
if seed is not None:
|
| 261 |
+
random.seed(seed)
|
| 262 |
+
np.random.seed(seed)
|
| 263 |
+
torch.manual_seed(seed)
|
| 264 |
+
torch.cuda.manual_seed_all(seed)
|
| 265 |
+
|
| 266 |
+
# # random generator for initial noise
|
| 267 |
+
# generator = torch.Generator(device="cuda").manual_seed(seed)
|
| 268 |
+
generator = None
|
| 269 |
+
else:
|
| 270 |
+
generator = None
|
| 271 |
+
|
| 272 |
+
# get the initial random noise unless the user supplied it
|
| 273 |
+
# SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している
|
| 274 |
+
# SDXL creates latents in CPU, Diffusers creates latents in target device
|
| 275 |
+
latents_shape = (1, 4, target_height // 8, target_width // 8)
|
| 276 |
+
latents = torch.randn(
|
| 277 |
+
latents_shape,
|
| 278 |
+
generator=generator,
|
| 279 |
+
device="cpu",
|
| 280 |
+
dtype=torch.float32,
|
| 281 |
+
).to(DEVICE, dtype=DTYPE)
|
| 282 |
+
|
| 283 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 284 |
+
latents = latents * scheduler.init_noise_sigma
|
| 285 |
+
|
| 286 |
+
# set timesteps
|
| 287 |
+
scheduler.set_timesteps(steps, DEVICE)
|
| 288 |
+
|
| 289 |
+
# このへんはDiffusersからのコピペ
|
| 290 |
+
# Copy from Diffusers
|
| 291 |
+
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
|
| 292 |
+
num_latent_input = 2
|
| 293 |
+
with torch.no_grad():
|
| 294 |
+
for i, t in enumerate(tqdm(timesteps)):
|
| 295 |
+
# expand the latents if we are doing classifier free guidance
|
| 296 |
+
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
| 297 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
| 298 |
+
|
| 299 |
+
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
|
| 300 |
+
|
| 301 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
|
| 302 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 303 |
+
|
| 304 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 305 |
+
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 306 |
+
latents = scheduler.step(noise_pred, t, latents).prev_sample
|
| 307 |
+
|
| 308 |
+
# latents = 1 / 0.18215 * latents
|
| 309 |
+
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
| 310 |
+
latents = latents.to(vae_dtype)
|
| 311 |
+
image = vae.decode(latents).sample
|
| 312 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 313 |
+
|
| 314 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
| 315 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 316 |
+
|
| 317 |
+
# image = self.numpy_to_pil(image)
|
| 318 |
+
image = (image * 255).round().astype("uint8")
|
| 319 |
+
image = [Image.fromarray(im) for im in image]
|
| 320 |
+
|
| 321 |
+
# 保存して終了 save and finish
|
| 322 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
| 323 |
+
for i, img in enumerate(image):
|
| 324 |
+
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
|
| 325 |
+
|
| 326 |
+
if not args.interactive:
|
| 327 |
+
generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
|
| 328 |
+
else:
|
| 329 |
+
# loop for interactive
|
| 330 |
+
while True:
|
| 331 |
+
prompt = input("prompt: ")
|
| 332 |
+
if prompt == "":
|
| 333 |
+
break
|
| 334 |
+
prompt2 = input("prompt2: ")
|
| 335 |
+
if prompt2 == "":
|
| 336 |
+
prompt2 = prompt
|
| 337 |
+
negative_prompt = input("negative prompt: ")
|
| 338 |
+
seed = input("seed: ")
|
| 339 |
+
if seed == "":
|
| 340 |
+
seed = None
|
| 341 |
+
else:
|
| 342 |
+
seed = int(seed)
|
| 343 |
+
generate_image(prompt, prompt2, negative_prompt, seed)
|
| 344 |
+
|
| 345 |
+
logger.info("Done!")
|
sdxl_train.py
ADDED
|
@@ -0,0 +1,952 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# training with captions
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from multiprocessing import Value
|
| 7 |
+
from typing import List
|
| 8 |
+
import toml
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
init_ipex()
|
| 17 |
+
|
| 18 |
+
from accelerate.utils import set_seed
|
| 19 |
+
from diffusers import DDPMScheduler
|
| 20 |
+
from library import deepspeed_utils, sdxl_model_util
|
| 21 |
+
|
| 22 |
+
import library.train_util as train_util
|
| 23 |
+
|
| 24 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 25 |
+
|
| 26 |
+
setup_logging()
|
| 27 |
+
import logging
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
import library.config_util as config_util
|
| 32 |
+
import library.sdxl_train_util as sdxl_train_util
|
| 33 |
+
from library.config_util import (
|
| 34 |
+
ConfigSanitizer,
|
| 35 |
+
BlueprintGenerator,
|
| 36 |
+
)
|
| 37 |
+
import library.custom_train_functions as custom_train_functions
|
| 38 |
+
from library.custom_train_functions import (
|
| 39 |
+
apply_snr_weight,
|
| 40 |
+
prepare_scheduler_for_custom_training,
|
| 41 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 42 |
+
add_v_prediction_like_loss,
|
| 43 |
+
apply_debiased_estimation,
|
| 44 |
+
apply_masked_loss,
|
| 45 |
+
)
|
| 46 |
+
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
|
| 53 |
+
block_params = [[] for _ in range(len(block_lrs))]
|
| 54 |
+
|
| 55 |
+
for i, (name, param) in enumerate(unet.named_parameters()):
|
| 56 |
+
if name.startswith("time_embed.") or name.startswith("label_emb."):
|
| 57 |
+
block_index = 0 # 0
|
| 58 |
+
elif name.startswith("input_blocks."): # 1-9
|
| 59 |
+
block_index = 1 + int(name.split(".")[1])
|
| 60 |
+
elif name.startswith("middle_block."): # 10-12
|
| 61 |
+
block_index = 10 + int(name.split(".")[1])
|
| 62 |
+
elif name.startswith("output_blocks."): # 13-21
|
| 63 |
+
block_index = 13 + int(name.split(".")[1])
|
| 64 |
+
elif name.startswith("out."): # 22
|
| 65 |
+
block_index = 22
|
| 66 |
+
else:
|
| 67 |
+
raise ValueError(f"unexpected parameter name: {name}")
|
| 68 |
+
|
| 69 |
+
block_params[block_index].append(param)
|
| 70 |
+
|
| 71 |
+
params_to_optimize = []
|
| 72 |
+
for i, params in enumerate(block_params):
|
| 73 |
+
if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0
|
| 74 |
+
continue
|
| 75 |
+
params_to_optimize.append({"params": params, "lr": block_lrs[i]})
|
| 76 |
+
|
| 77 |
+
return params_to_optimize
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
|
| 81 |
+
names = []
|
| 82 |
+
block_index = 0
|
| 83 |
+
while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
|
| 84 |
+
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
|
| 85 |
+
if block_lrs[block_index] == 0:
|
| 86 |
+
block_index += 1
|
| 87 |
+
continue
|
| 88 |
+
names.append(f"block{block_index}")
|
| 89 |
+
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
|
| 90 |
+
names.append("text_encoder1")
|
| 91 |
+
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
|
| 92 |
+
names.append("text_encoder2")
|
| 93 |
+
|
| 94 |
+
block_index += 1
|
| 95 |
+
|
| 96 |
+
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def train(args):
|
| 100 |
+
train_util.verify_training_args(args)
|
| 101 |
+
train_util.prepare_dataset_args(args, True)
|
| 102 |
+
sdxl_train_util.verify_sdxl_training_args(args)
|
| 103 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
| 104 |
+
setup_logging(args, reset=True)
|
| 105 |
+
|
| 106 |
+
assert (
|
| 107 |
+
not args.weighted_captions
|
| 108 |
+
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
| 109 |
+
assert (
|
| 110 |
+
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
| 111 |
+
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
| 112 |
+
|
| 113 |
+
if args.block_lr:
|
| 114 |
+
block_lrs = [float(lr) for lr in args.block_lr.split(",")]
|
| 115 |
+
assert (
|
| 116 |
+
len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
|
| 117 |
+
), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
|
| 118 |
+
else:
|
| 119 |
+
block_lrs = None
|
| 120 |
+
|
| 121 |
+
cache_latents = args.cache_latents
|
| 122 |
+
use_dreambooth_method = args.in_json is None
|
| 123 |
+
|
| 124 |
+
if args.seed is not None:
|
| 125 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
| 126 |
+
|
| 127 |
+
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
| 128 |
+
|
| 129 |
+
# データセットを準備する
|
| 130 |
+
if args.dataset_class is None:
|
| 131 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
| 132 |
+
if args.dataset_config is not None:
|
| 133 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 134 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 135 |
+
ignored = ["train_data_dir", "in_json"]
|
| 136 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 137 |
+
logger.warning(
|
| 138 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無���されます: {0}".format(
|
| 139 |
+
", ".join(ignored)
|
| 140 |
+
)
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
if use_dreambooth_method:
|
| 144 |
+
logger.info("Using DreamBooth method.")
|
| 145 |
+
user_config = {
|
| 146 |
+
"datasets": [
|
| 147 |
+
{
|
| 148 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
| 149 |
+
args.train_data_dir, args.reg_data_dir
|
| 150 |
+
)
|
| 151 |
+
}
|
| 152 |
+
]
|
| 153 |
+
}
|
| 154 |
+
else:
|
| 155 |
+
logger.info("Training with captions.")
|
| 156 |
+
user_config = {
|
| 157 |
+
"datasets": [
|
| 158 |
+
{
|
| 159 |
+
"subsets": [
|
| 160 |
+
{
|
| 161 |
+
"image_dir": args.train_data_dir,
|
| 162 |
+
"metadata_file": args.in_json,
|
| 163 |
+
}
|
| 164 |
+
]
|
| 165 |
+
}
|
| 166 |
+
]
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
| 170 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 171 |
+
else:
|
| 172 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
|
| 173 |
+
|
| 174 |
+
current_epoch = Value("i", 0)
|
| 175 |
+
current_step = Value("i", 0)
|
| 176 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 177 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 178 |
+
|
| 179 |
+
train_dataset_group.verify_bucket_reso_steps(32)
|
| 180 |
+
|
| 181 |
+
if args.debug_dataset:
|
| 182 |
+
train_util.debug_dataset(train_dataset_group, True)
|
| 183 |
+
return
|
| 184 |
+
if len(train_dataset_group) == 0:
|
| 185 |
+
logger.error(
|
| 186 |
+
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
| 187 |
+
)
|
| 188 |
+
return
|
| 189 |
+
|
| 190 |
+
if cache_latents:
|
| 191 |
+
assert (
|
| 192 |
+
train_dataset_group.is_latent_cacheable()
|
| 193 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 194 |
+
|
| 195 |
+
if args.cache_text_encoder_outputs:
|
| 196 |
+
assert (
|
| 197 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
| 198 |
+
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
| 199 |
+
|
| 200 |
+
# acceleratorを準備する
|
| 201 |
+
logger.info("prepare accelerator")
|
| 202 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 203 |
+
|
| 204 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 205 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 206 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 207 |
+
|
| 208 |
+
# モデルを読み込む
|
| 209 |
+
(
|
| 210 |
+
load_stable_diffusion_format,
|
| 211 |
+
text_encoder1,
|
| 212 |
+
text_encoder2,
|
| 213 |
+
vae,
|
| 214 |
+
unet,
|
| 215 |
+
logit_scale,
|
| 216 |
+
ckpt_info,
|
| 217 |
+
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
| 218 |
+
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
| 219 |
+
|
| 220 |
+
# verify load/save model formats
|
| 221 |
+
if load_stable_diffusion_format:
|
| 222 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
| 223 |
+
src_diffusers_model_path = None
|
| 224 |
+
else:
|
| 225 |
+
src_stable_diffusion_ckpt = None
|
| 226 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
| 227 |
+
|
| 228 |
+
if args.save_model_as is None:
|
| 229 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
| 230 |
+
use_safetensors = args.use_safetensors
|
| 231 |
+
else:
|
| 232 |
+
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
| 233 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
| 234 |
+
# assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
| 235 |
+
|
| 236 |
+
# Diffusers版のxformers使用フラグを設定する関数
|
| 237 |
+
def set_diffusers_xformers_flag(model, valid):
|
| 238 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
| 239 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
| 240 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
| 241 |
+
|
| 242 |
+
for child in module.children():
|
| 243 |
+
fn_recursive_set_mem_eff(child)
|
| 244 |
+
|
| 245 |
+
fn_recursive_set_mem_eff(model)
|
| 246 |
+
|
| 247 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 248 |
+
if args.diffusers_xformers:
|
| 249 |
+
# もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず
|
| 250 |
+
accelerator.print("Use xformers by Diffusers")
|
| 251 |
+
# set_diffusers_xformers_flag(unet, True)
|
| 252 |
+
set_diffusers_xformers_flag(vae, True)
|
| 253 |
+
else:
|
| 254 |
+
# Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある
|
| 255 |
+
accelerator.print("Disable Diffusers' xformers")
|
| 256 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 257 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
| 258 |
+
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
| 259 |
+
|
| 260 |
+
# 学習を準備する
|
| 261 |
+
if cache_latents:
|
| 262 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 263 |
+
vae.requires_grad_(False)
|
| 264 |
+
vae.eval()
|
| 265 |
+
with torch.no_grad():
|
| 266 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
| 267 |
+
vae.to("cpu")
|
| 268 |
+
clean_memory_on_device(accelerator.device)
|
| 269 |
+
|
| 270 |
+
accelerator.wait_for_everyone()
|
| 271 |
+
|
| 272 |
+
# 学習を準備する:モデルを適切な状態にする
|
| 273 |
+
if args.gradient_checkpointing:
|
| 274 |
+
unet.enable_gradient_checkpointing()
|
| 275 |
+
train_unet = args.learning_rate != 0
|
| 276 |
+
train_text_encoder1 = False
|
| 277 |
+
train_text_encoder2 = False
|
| 278 |
+
|
| 279 |
+
if args.train_text_encoder:
|
| 280 |
+
# TODO each option for two text encoders?
|
| 281 |
+
accelerator.print("enable text encoder training")
|
| 282 |
+
if args.gradient_checkpointing:
|
| 283 |
+
text_encoder1.gradient_checkpointing_enable()
|
| 284 |
+
text_encoder2.gradient_checkpointing_enable()
|
| 285 |
+
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
| 286 |
+
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
| 287 |
+
train_text_encoder1 = lr_te1 != 0
|
| 288 |
+
train_text_encoder2 = lr_te2 != 0
|
| 289 |
+
|
| 290 |
+
# caching one text encoder output is not supported
|
| 291 |
+
if not train_text_encoder1:
|
| 292 |
+
text_encoder1.to(weight_dtype)
|
| 293 |
+
if not train_text_encoder2:
|
| 294 |
+
text_encoder2.to(weight_dtype)
|
| 295 |
+
text_encoder1.requires_grad_(train_text_encoder1)
|
| 296 |
+
text_encoder2.requires_grad_(train_text_encoder2)
|
| 297 |
+
text_encoder1.train(train_text_encoder1)
|
| 298 |
+
text_encoder2.train(train_text_encoder2)
|
| 299 |
+
else:
|
| 300 |
+
text_encoder1.to(weight_dtype)
|
| 301 |
+
text_encoder2.to(weight_dtype)
|
| 302 |
+
text_encoder1.requires_grad_(False)
|
| 303 |
+
text_encoder2.requires_grad_(False)
|
| 304 |
+
text_encoder1.eval()
|
| 305 |
+
text_encoder2.eval()
|
| 306 |
+
|
| 307 |
+
# TextEncoderの出力をキャッシュする
|
| 308 |
+
if args.cache_text_encoder_outputs:
|
| 309 |
+
# Text Encodes are eval and no grad
|
| 310 |
+
with torch.no_grad(), accelerator.autocast():
|
| 311 |
+
train_dataset_group.cache_text_encoder_outputs(
|
| 312 |
+
(tokenizer1, tokenizer2),
|
| 313 |
+
(text_encoder1, text_encoder2),
|
| 314 |
+
accelerator.device,
|
| 315 |
+
None,
|
| 316 |
+
args.cache_text_encoder_outputs_to_disk,
|
| 317 |
+
accelerator.is_main_process,
|
| 318 |
+
)
|
| 319 |
+
accelerator.wait_for_everyone()
|
| 320 |
+
|
| 321 |
+
if not cache_latents:
|
| 322 |
+
vae.requires_grad_(False)
|
| 323 |
+
vae.eval()
|
| 324 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 325 |
+
|
| 326 |
+
unet.requires_grad_(train_unet)
|
| 327 |
+
if not train_unet:
|
| 328 |
+
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
| 329 |
+
|
| 330 |
+
training_models = []
|
| 331 |
+
params_to_optimize = []
|
| 332 |
+
if train_unet:
|
| 333 |
+
training_models.append(unet)
|
| 334 |
+
if block_lrs is None:
|
| 335 |
+
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
|
| 336 |
+
else:
|
| 337 |
+
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
|
| 338 |
+
|
| 339 |
+
if train_text_encoder1:
|
| 340 |
+
training_models.append(text_encoder1)
|
| 341 |
+
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
| 342 |
+
if train_text_encoder2:
|
| 343 |
+
training_models.append(text_encoder2)
|
| 344 |
+
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
| 345 |
+
|
| 346 |
+
# calculate number of trainable parameters
|
| 347 |
+
n_params = 0
|
| 348 |
+
for group in params_to_optimize:
|
| 349 |
+
for p in group["params"]:
|
| 350 |
+
n_params += p.numel()
|
| 351 |
+
|
| 352 |
+
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
|
| 353 |
+
accelerator.print(f"number of models: {len(training_models)}")
|
| 354 |
+
accelerator.print(f"number of trainable parameters: {n_params}")
|
| 355 |
+
|
| 356 |
+
# 学習に必要なクラスを準備する
|
| 357 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 358 |
+
|
| 359 |
+
if args.fused_optimizer_groups:
|
| 360 |
+
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
|
| 361 |
+
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
|
| 362 |
+
# This balances memory usage and management complexity.
|
| 363 |
+
|
| 364 |
+
# calculate total number of parameters
|
| 365 |
+
n_total_params = sum(len(params["params"]) for params in params_to_optimize)
|
| 366 |
+
params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
|
| 367 |
+
|
| 368 |
+
# split params into groups, keeping the learning rate the same for all params in a group
|
| 369 |
+
# this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
|
| 370 |
+
grouped_params = []
|
| 371 |
+
param_group = []
|
| 372 |
+
param_group_lr = -1
|
| 373 |
+
for group in params_to_optimize:
|
| 374 |
+
lr = group["lr"]
|
| 375 |
+
for p in group["params"]:
|
| 376 |
+
# if the learning rate is different for different params, start a new group
|
| 377 |
+
if lr != param_group_lr:
|
| 378 |
+
if param_group:
|
| 379 |
+
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
| 380 |
+
param_group = []
|
| 381 |
+
param_group_lr = lr
|
| 382 |
+
|
| 383 |
+
param_group.append(p)
|
| 384 |
+
|
| 385 |
+
# if the group has enough parameters, start a new group
|
| 386 |
+
if len(param_group) == params_per_group:
|
| 387 |
+
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
| 388 |
+
param_group = []
|
| 389 |
+
param_group_lr = -1
|
| 390 |
+
|
| 391 |
+
if param_group:
|
| 392 |
+
grouped_params.append({"params": param_group, "lr": param_group_lr})
|
| 393 |
+
|
| 394 |
+
# prepare optimizers for each group
|
| 395 |
+
optimizers = []
|
| 396 |
+
for group in grouped_params:
|
| 397 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
|
| 398 |
+
optimizers.append(optimizer)
|
| 399 |
+
optimizer = optimizers[0] # avoid error in the following code
|
| 400 |
+
|
| 401 |
+
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
|
| 402 |
+
|
| 403 |
+
else:
|
| 404 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
| 405 |
+
|
| 406 |
+
# dataloaderを準備する
|
| 407 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 408 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 409 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 410 |
+
train_dataset_group,
|
| 411 |
+
batch_size=1,
|
| 412 |
+
shuffle=True,
|
| 413 |
+
collate_fn=collator,
|
| 414 |
+
num_workers=n_workers,
|
| 415 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
# 学習ステップ数を計算する
|
| 419 |
+
if args.max_train_epochs is not None:
|
| 420 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 421 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 422 |
+
)
|
| 423 |
+
accelerator.print(
|
| 424 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# データセット側にも学習ステップを送信
|
| 428 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 429 |
+
|
| 430 |
+
# lr schedulerを用意する
|
| 431 |
+
if args.fused_optimizer_groups:
|
| 432 |
+
# prepare lr schedulers for each optimizer
|
| 433 |
+
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
|
| 434 |
+
lr_scheduler = lr_schedulers[0] # avoid error in the following code
|
| 435 |
+
else:
|
| 436 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 437 |
+
|
| 438 |
+
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
| 439 |
+
if args.full_fp16:
|
| 440 |
+
assert (
|
| 441 |
+
args.mixed_precision == "fp16"
|
| 442 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 443 |
+
accelerator.print("enable full fp16 training.")
|
| 444 |
+
unet.to(weight_dtype)
|
| 445 |
+
text_encoder1.to(weight_dtype)
|
| 446 |
+
text_encoder2.to(weight_dtype)
|
| 447 |
+
elif args.full_bf16:
|
| 448 |
+
assert (
|
| 449 |
+
args.mixed_precision == "bf16"
|
| 450 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
| 451 |
+
accelerator.print("enable full bf16 training.")
|
| 452 |
+
unet.to(weight_dtype)
|
| 453 |
+
text_encoder1.to(weight_dtype)
|
| 454 |
+
text_encoder2.to(weight_dtype)
|
| 455 |
+
|
| 456 |
+
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
| 457 |
+
if train_text_encoder1:
|
| 458 |
+
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
| 459 |
+
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
| 460 |
+
|
| 461 |
+
if args.deepspeed:
|
| 462 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
| 463 |
+
args,
|
| 464 |
+
unet=unet if train_unet else None,
|
| 465 |
+
text_encoder1=text_encoder1 if train_text_encoder1 else None,
|
| 466 |
+
text_encoder2=text_encoder2 if train_text_encoder2 else None,
|
| 467 |
+
)
|
| 468 |
+
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
| 469 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 470 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
| 471 |
+
)
|
| 472 |
+
training_models = [ds_model]
|
| 473 |
+
|
| 474 |
+
else:
|
| 475 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 476 |
+
if train_unet:
|
| 477 |
+
unet = accelerator.prepare(unet)
|
| 478 |
+
if train_text_encoder1:
|
| 479 |
+
text_encoder1 = accelerator.prepare(text_encoder1)
|
| 480 |
+
if train_text_encoder2:
|
| 481 |
+
text_encoder2 = accelerator.prepare(text_encoder2)
|
| 482 |
+
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
| 483 |
+
|
| 484 |
+
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
| 485 |
+
if args.cache_text_encoder_outputs:
|
| 486 |
+
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
| 487 |
+
text_encoder1.to("cpu", dtype=torch.float32)
|
| 488 |
+
text_encoder2.to("cpu", dtype=torch.float32)
|
| 489 |
+
clean_memory_on_device(accelerator.device)
|
| 490 |
+
else:
|
| 491 |
+
# make sure Text Encoders are on GPU
|
| 492 |
+
text_encoder1.to(accelerator.device)
|
| 493 |
+
text_encoder2.to(accelerator.device)
|
| 494 |
+
|
| 495 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 496 |
+
if args.full_fp16:
|
| 497 |
+
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
|
| 498 |
+
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
|
| 499 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 500 |
+
|
| 501 |
+
# resumeする
|
| 502 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 503 |
+
|
| 504 |
+
if args.fused_backward_pass:
|
| 505 |
+
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
| 506 |
+
import library.adafactor_fused
|
| 507 |
+
|
| 508 |
+
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
| 509 |
+
for param_group in optimizer.param_groups:
|
| 510 |
+
for parameter in param_group["params"]:
|
| 511 |
+
if parameter.requires_grad:
|
| 512 |
+
|
| 513 |
+
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
|
| 514 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 515 |
+
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
| 516 |
+
optimizer.step_param(tensor, param_group)
|
| 517 |
+
tensor.grad = None
|
| 518 |
+
|
| 519 |
+
parameter.register_post_accumulate_grad_hook(__grad_hook)
|
| 520 |
+
|
| 521 |
+
elif args.fused_optimizer_groups:
|
| 522 |
+
# prepare for additional optimizers and lr schedulers
|
| 523 |
+
for i in range(1, len(optimizers)):
|
| 524 |
+
optimizers[i] = accelerator.prepare(optimizers[i])
|
| 525 |
+
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
| 526 |
+
|
| 527 |
+
# counters are used to determine when to step the optimizer
|
| 528 |
+
global optimizer_hooked_count
|
| 529 |
+
global num_parameters_per_group
|
| 530 |
+
global parameter_optimizer_map
|
| 531 |
+
|
| 532 |
+
optimizer_hooked_count = {}
|
| 533 |
+
num_parameters_per_group = [0] * len(optimizers)
|
| 534 |
+
parameter_optimizer_map = {}
|
| 535 |
+
|
| 536 |
+
for opt_idx, optimizer in enumerate(optimizers):
|
| 537 |
+
for param_group in optimizer.param_groups:
|
| 538 |
+
for parameter in param_group["params"]:
|
| 539 |
+
if parameter.requires_grad:
|
| 540 |
+
|
| 541 |
+
def optimizer_hook(parameter: torch.Tensor):
|
| 542 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 543 |
+
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
| 544 |
+
|
| 545 |
+
i = parameter_optimizer_map[parameter]
|
| 546 |
+
optimizer_hooked_count[i] += 1
|
| 547 |
+
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
| 548 |
+
optimizers[i].step()
|
| 549 |
+
optimizers[i].zero_grad(set_to_none=True)
|
| 550 |
+
|
| 551 |
+
parameter.register_post_accumulate_grad_hook(optimizer_hook)
|
| 552 |
+
parameter_optimizer_map[parameter] = opt_idx
|
| 553 |
+
num_parameters_per_group[opt_idx] += 1
|
| 554 |
+
|
| 555 |
+
# epoch数を計算する
|
| 556 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 557 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 558 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 559 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 560 |
+
|
| 561 |
+
# 学習する
|
| 562 |
+
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 563 |
+
accelerator.print("running training / 学習開始")
|
| 564 |
+
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
| 565 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 566 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 567 |
+
accelerator.print(
|
| 568 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
| 569 |
+
)
|
| 570 |
+
# accelerator.print(
|
| 571 |
+
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
| 572 |
+
# )
|
| 573 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 574 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 575 |
+
|
| 576 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 577 |
+
global_step = 0
|
| 578 |
+
|
| 579 |
+
noise_scheduler = DDPMScheduler(
|
| 580 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 581 |
+
)
|
| 582 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 583 |
+
if args.zero_terminal_snr:
|
| 584 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 585 |
+
|
| 586 |
+
if accelerator.is_main_process:
|
| 587 |
+
init_kwargs = {}
|
| 588 |
+
if args.wandb_run_name:
|
| 589 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 590 |
+
if args.log_tracker_config is not None:
|
| 591 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 592 |
+
accelerator.init_trackers(
|
| 593 |
+
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
| 594 |
+
config=train_util.get_sanitized_config_or_none(args),
|
| 595 |
+
init_kwargs=init_kwargs,
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
# For --sample_at_first
|
| 599 |
+
sdxl_train_util.sample_images(
|
| 600 |
+
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
loss_recorder = train_util.LossRecorder()
|
| 604 |
+
for epoch in range(num_train_epochs):
|
| 605 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 606 |
+
current_epoch.value = epoch + 1
|
| 607 |
+
|
| 608 |
+
for m in training_models:
|
| 609 |
+
m.train()
|
| 610 |
+
|
| 611 |
+
for step, batch in enumerate(train_dataloader):
|
| 612 |
+
current_step.value = global_step
|
| 613 |
+
|
| 614 |
+
if args.fused_optimizer_groups:
|
| 615 |
+
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
| 616 |
+
|
| 617 |
+
with accelerator.accumulate(*training_models):
|
| 618 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 619 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 620 |
+
else:
|
| 621 |
+
with torch.no_grad():
|
| 622 |
+
# latentに変換
|
| 623 |
+
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
| 624 |
+
|
| 625 |
+
# NaNが含まれていれば警告を表示し0に置き換える
|
| 626 |
+
if torch.any(torch.isnan(latents)):
|
| 627 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
| 628 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
| 629 |
+
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
| 630 |
+
|
| 631 |
+
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
| 632 |
+
input_ids1 = batch["input_ids"]
|
| 633 |
+
input_ids2 = batch["input_ids2"]
|
| 634 |
+
with torch.set_grad_enabled(args.train_text_encoder):
|
| 635 |
+
# Get the text embedding for conditioning
|
| 636 |
+
# TODO support weighted captions
|
| 637 |
+
# if args.weighted_captions:
|
| 638 |
+
# encoder_hidden_states = get_weighted_text_embeddings(
|
| 639 |
+
# tokenizer,
|
| 640 |
+
# text_encoder,
|
| 641 |
+
# batch["captions"],
|
| 642 |
+
# accelerator.device,
|
| 643 |
+
# args.max_token_length // 75 if args.max_token_length else 1,
|
| 644 |
+
# clip_skip=args.clip_skip,
|
| 645 |
+
# )
|
| 646 |
+
# else:
|
| 647 |
+
input_ids1 = input_ids1.to(accelerator.device)
|
| 648 |
+
input_ids2 = input_ids2.to(accelerator.device)
|
| 649 |
+
# unwrap_model is fine for models not wrapped by accelerator
|
| 650 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
| 651 |
+
args.max_token_length,
|
| 652 |
+
input_ids1,
|
| 653 |
+
input_ids2,
|
| 654 |
+
tokenizer1,
|
| 655 |
+
tokenizer2,
|
| 656 |
+
text_encoder1,
|
| 657 |
+
text_encoder2,
|
| 658 |
+
None if not args.full_fp16 else weight_dtype,
|
| 659 |
+
accelerator=accelerator,
|
| 660 |
+
)
|
| 661 |
+
else:
|
| 662 |
+
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
| 663 |
+
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
| 664 |
+
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
| 665 |
+
|
| 666 |
+
# # verify that the text encoder outputs are correct
|
| 667 |
+
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
| 668 |
+
# args.max_token_length,
|
| 669 |
+
# batch["input_ids"].to(text_encoder1.device),
|
| 670 |
+
# batch["input_ids2"].to(text_encoder1.device),
|
| 671 |
+
# tokenizer1,
|
| 672 |
+
# tokenizer2,
|
| 673 |
+
# text_encoder1,
|
| 674 |
+
# text_encoder2,
|
| 675 |
+
# None if not args.full_fp16 else weight_dtype,
|
| 676 |
+
# )
|
| 677 |
+
# b_size = encoder_hidden_states1.shape[0]
|
| 678 |
+
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
| 679 |
+
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
| 680 |
+
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
| 681 |
+
# logger.info("text encoder outputs verified")
|
| 682 |
+
|
| 683 |
+
# get size embeddings
|
| 684 |
+
orig_size = batch["original_sizes_hw"]
|
| 685 |
+
crop_size = batch["crop_top_lefts"]
|
| 686 |
+
target_size = batch["target_sizes_hw"]
|
| 687 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
| 688 |
+
|
| 689 |
+
# concat embeddings
|
| 690 |
+
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
| 691 |
+
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
| 692 |
+
|
| 693 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 694 |
+
# with noise offset and/or multires noise if specified
|
| 695 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
| 696 |
+
args, noise_scheduler, latents
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
| 700 |
+
|
| 701 |
+
# Predict the noise residual
|
| 702 |
+
with accelerator.autocast():
|
| 703 |
+
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
| 704 |
+
|
| 705 |
+
if args.v_parameterization:
|
| 706 |
+
# v-parameterization training
|
| 707 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 708 |
+
else:
|
| 709 |
+
target = noise
|
| 710 |
+
|
| 711 |
+
if (
|
| 712 |
+
args.min_snr_gamma
|
| 713 |
+
or args.scale_v_pred_loss_like_noise_pred
|
| 714 |
+
or args.v_pred_like_loss
|
| 715 |
+
or args.debiased_estimation_loss
|
| 716 |
+
or args.masked_loss
|
| 717 |
+
):
|
| 718 |
+
# do not mean over batch dimension for snr weight or scale v-pred loss
|
| 719 |
+
loss = train_util.conditional_loss(
|
| 720 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
| 721 |
+
)
|
| 722 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
| 723 |
+
loss = apply_masked_loss(loss, batch)
|
| 724 |
+
loss = loss.mean([1, 2, 3])
|
| 725 |
+
|
| 726 |
+
if args.min_snr_gamma:
|
| 727 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 728 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 729 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 730 |
+
if args.v_pred_like_loss:
|
| 731 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
| 732 |
+
if args.debiased_estimation_loss:
|
| 733 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 734 |
+
|
| 735 |
+
loss = loss.mean() # mean over batch dimension
|
| 736 |
+
else:
|
| 737 |
+
loss = train_util.conditional_loss(
|
| 738 |
+
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
accelerator.backward(loss)
|
| 742 |
+
|
| 743 |
+
if not (args.fused_backward_pass or args.fused_optimizer_groups):
|
| 744 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 745 |
+
params_to_clip = []
|
| 746 |
+
for m in training_models:
|
| 747 |
+
params_to_clip.extend(m.parameters())
|
| 748 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 749 |
+
|
| 750 |
+
optimizer.step()
|
| 751 |
+
lr_scheduler.step()
|
| 752 |
+
optimizer.zero_grad(set_to_none=True)
|
| 753 |
+
else:
|
| 754 |
+
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
| 755 |
+
lr_scheduler.step()
|
| 756 |
+
if args.fused_optimizer_groups:
|
| 757 |
+
for i in range(1, len(optimizers)):
|
| 758 |
+
lr_schedulers[i].step()
|
| 759 |
+
|
| 760 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 761 |
+
if accelerator.sync_gradients:
|
| 762 |
+
progress_bar.update(1)
|
| 763 |
+
global_step += 1
|
| 764 |
+
|
| 765 |
+
sdxl_train_util.sample_images(
|
| 766 |
+
accelerator,
|
| 767 |
+
args,
|
| 768 |
+
None,
|
| 769 |
+
global_step,
|
| 770 |
+
accelerator.device,
|
| 771 |
+
vae,
|
| 772 |
+
[tokenizer1, tokenizer2],
|
| 773 |
+
[text_encoder1, text_encoder2],
|
| 774 |
+
unet,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
# 指定ステップごとにモデルを保存
|
| 778 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 779 |
+
accelerator.wait_for_everyone()
|
| 780 |
+
if accelerator.is_main_process:
|
| 781 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 782 |
+
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
|
| 783 |
+
args,
|
| 784 |
+
False,
|
| 785 |
+
accelerator,
|
| 786 |
+
src_path,
|
| 787 |
+
save_stable_diffusion_format,
|
| 788 |
+
use_safetensors,
|
| 789 |
+
save_dtype,
|
| 790 |
+
epoch,
|
| 791 |
+
num_train_epochs,
|
| 792 |
+
global_step,
|
| 793 |
+
accelerator.unwrap_model(text_encoder1),
|
| 794 |
+
accelerator.unwrap_model(text_encoder2),
|
| 795 |
+
accelerator.unwrap_model(unet),
|
| 796 |
+
vae,
|
| 797 |
+
logit_scale,
|
| 798 |
+
ckpt_info,
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
| 802 |
+
if args.logging_dir is not None:
|
| 803 |
+
logs = {"loss": current_loss}
|
| 804 |
+
if block_lrs is None:
|
| 805 |
+
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
|
| 806 |
+
else:
|
| 807 |
+
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
|
| 808 |
+
|
| 809 |
+
accelerator.log(logs, step=global_step)
|
| 810 |
+
|
| 811 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 812 |
+
avr_loss: float = loss_recorder.moving_average
|
| 813 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 814 |
+
progress_bar.set_postfix(**logs)
|
| 815 |
+
|
| 816 |
+
if global_step >= args.max_train_steps:
|
| 817 |
+
break
|
| 818 |
+
|
| 819 |
+
if args.logging_dir is not None:
|
| 820 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 821 |
+
accelerator.log(logs, step=epoch + 1)
|
| 822 |
+
|
| 823 |
+
accelerator.wait_for_everyone()
|
| 824 |
+
|
| 825 |
+
if args.save_every_n_epochs is not None:
|
| 826 |
+
if accelerator.is_main_process:
|
| 827 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 828 |
+
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
|
| 829 |
+
args,
|
| 830 |
+
True,
|
| 831 |
+
accelerator,
|
| 832 |
+
src_path,
|
| 833 |
+
save_stable_diffusion_format,
|
| 834 |
+
use_safetensors,
|
| 835 |
+
save_dtype,
|
| 836 |
+
epoch,
|
| 837 |
+
num_train_epochs,
|
| 838 |
+
global_step,
|
| 839 |
+
accelerator.unwrap_model(text_encoder1),
|
| 840 |
+
accelerator.unwrap_model(text_encoder2),
|
| 841 |
+
accelerator.unwrap_model(unet),
|
| 842 |
+
vae,
|
| 843 |
+
logit_scale,
|
| 844 |
+
ckpt_info,
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
sdxl_train_util.sample_images(
|
| 848 |
+
accelerator,
|
| 849 |
+
args,
|
| 850 |
+
epoch + 1,
|
| 851 |
+
global_step,
|
| 852 |
+
accelerator.device,
|
| 853 |
+
vae,
|
| 854 |
+
[tokenizer1, tokenizer2],
|
| 855 |
+
[text_encoder1, text_encoder2],
|
| 856 |
+
unet,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
is_main_process = accelerator.is_main_process
|
| 860 |
+
# if is_main_process:
|
| 861 |
+
unet = accelerator.unwrap_model(unet)
|
| 862 |
+
text_encoder1 = accelerator.unwrap_model(text_encoder1)
|
| 863 |
+
text_encoder2 = accelerator.unwrap_model(text_encoder2)
|
| 864 |
+
|
| 865 |
+
accelerator.end_training()
|
| 866 |
+
|
| 867 |
+
if args.save_state or args.save_state_on_train_end:
|
| 868 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 869 |
+
|
| 870 |
+
del accelerator # この後メモリを使うのでこれは消す
|
| 871 |
+
|
| 872 |
+
if is_main_process:
|
| 873 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 874 |
+
sdxl_train_util.save_sd_model_on_train_end(
|
| 875 |
+
args,
|
| 876 |
+
src_path,
|
| 877 |
+
save_stable_diffusion_format,
|
| 878 |
+
use_safetensors,
|
| 879 |
+
save_dtype,
|
| 880 |
+
epoch,
|
| 881 |
+
global_step,
|
| 882 |
+
text_encoder1,
|
| 883 |
+
text_encoder2,
|
| 884 |
+
unet,
|
| 885 |
+
vae,
|
| 886 |
+
logit_scale,
|
| 887 |
+
ckpt_info,
|
| 888 |
+
)
|
| 889 |
+
logger.info("model saved.")
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 893 |
+
parser = argparse.ArgumentParser()
|
| 894 |
+
|
| 895 |
+
add_logging_arguments(parser)
|
| 896 |
+
train_util.add_sd_models_arguments(parser)
|
| 897 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
| 898 |
+
train_util.add_training_arguments(parser, False)
|
| 899 |
+
train_util.add_masked_loss_arguments(parser)
|
| 900 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 901 |
+
train_util.add_sd_saving_arguments(parser)
|
| 902 |
+
train_util.add_optimizer_arguments(parser)
|
| 903 |
+
config_util.add_config_arguments(parser)
|
| 904 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 905 |
+
sdxl_train_util.add_sdxl_training_arguments(parser)
|
| 906 |
+
|
| 907 |
+
parser.add_argument(
|
| 908 |
+
"--learning_rate_te1",
|
| 909 |
+
type=float,
|
| 910 |
+
default=None,
|
| 911 |
+
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
|
| 912 |
+
)
|
| 913 |
+
parser.add_argument(
|
| 914 |
+
"--learning_rate_te2",
|
| 915 |
+
type=float,
|
| 916 |
+
default=None,
|
| 917 |
+
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
|
| 918 |
+
)
|
| 919 |
+
|
| 920 |
+
parser.add_argument(
|
| 921 |
+
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
| 922 |
+
)
|
| 923 |
+
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
| 924 |
+
parser.add_argument(
|
| 925 |
+
"--no_half_vae",
|
| 926 |
+
action="store_true",
|
| 927 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 928 |
+
)
|
| 929 |
+
parser.add_argument(
|
| 930 |
+
"--block_lr",
|
| 931 |
+
type=str,
|
| 932 |
+
default=None,
|
| 933 |
+
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
|
| 934 |
+
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
|
| 935 |
+
)
|
| 936 |
+
parser.add_argument(
|
| 937 |
+
"--fused_optimizer_groups",
|
| 938 |
+
type=int,
|
| 939 |
+
default=None,
|
| 940 |
+
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
|
| 941 |
+
)
|
| 942 |
+
return parser
|
| 943 |
+
|
| 944 |
+
|
| 945 |
+
if __name__ == "__main__":
|
| 946 |
+
parser = setup_parser()
|
| 947 |
+
|
| 948 |
+
args = parser.parse_args()
|
| 949 |
+
train_util.verify_command_line_training_args(args)
|
| 950 |
+
args = train_util.read_config_from_file(args, parser)
|
| 951 |
+
|
| 952 |
+
train(args)
|
sdxl_train_control_net_lllite.py
ADDED
|
@@ -0,0 +1,626 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
|
| 2 |
+
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import time
|
| 10 |
+
from multiprocessing import Value
|
| 11 |
+
from types import SimpleNamespace
|
| 12 |
+
import toml
|
| 13 |
+
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 18 |
+
|
| 19 |
+
init_ipex()
|
| 20 |
+
|
| 21 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 22 |
+
from accelerate.utils import set_seed
|
| 23 |
+
import accelerate
|
| 24 |
+
from diffusers import DDPMScheduler, ControlNetModel
|
| 25 |
+
from safetensors.torch import load_file
|
| 26 |
+
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
| 27 |
+
|
| 28 |
+
import library.model_util as model_util
|
| 29 |
+
import library.train_util as train_util
|
| 30 |
+
import library.config_util as config_util
|
| 31 |
+
from library.config_util import (
|
| 32 |
+
ConfigSanitizer,
|
| 33 |
+
BlueprintGenerator,
|
| 34 |
+
)
|
| 35 |
+
import library.huggingface_util as huggingface_util
|
| 36 |
+
import library.custom_train_functions as custom_train_functions
|
| 37 |
+
from library.custom_train_functions import (
|
| 38 |
+
add_v_prediction_like_loss,
|
| 39 |
+
apply_snr_weight,
|
| 40 |
+
prepare_scheduler_for_custom_training,
|
| 41 |
+
pyramid_noise_like,
|
| 42 |
+
apply_noise_offset,
|
| 43 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 44 |
+
apply_debiased_estimation,
|
| 45 |
+
)
|
| 46 |
+
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
| 47 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 48 |
+
|
| 49 |
+
setup_logging()
|
| 50 |
+
import logging
|
| 51 |
+
|
| 52 |
+
logger = logging.getLogger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# TODO 他のスクリプトと共通化する
|
| 56 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
| 57 |
+
logs = {
|
| 58 |
+
"loss/current": current_loss,
|
| 59 |
+
"loss/average": avr_loss,
|
| 60 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
| 64 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
| 65 |
+
|
| 66 |
+
return logs
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def train(args):
|
| 70 |
+
train_util.verify_training_args(args)
|
| 71 |
+
train_util.prepare_dataset_args(args, True)
|
| 72 |
+
sdxl_train_util.verify_sdxl_training_args(args)
|
| 73 |
+
setup_logging(args, reset=True)
|
| 74 |
+
|
| 75 |
+
cache_latents = args.cache_latents
|
| 76 |
+
use_user_config = args.dataset_config is not None
|
| 77 |
+
|
| 78 |
+
if args.seed is None:
|
| 79 |
+
args.seed = random.randint(0, 2**32)
|
| 80 |
+
set_seed(args.seed)
|
| 81 |
+
|
| 82 |
+
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
| 83 |
+
|
| 84 |
+
# データセットを準備する
|
| 85 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
| 86 |
+
if use_user_config:
|
| 87 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 88 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 89 |
+
ignored = ["train_data_dir", "conditioning_data_dir"]
|
| 90 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 91 |
+
logger.warning(
|
| 92 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 93 |
+
", ".join(ignored)
|
| 94 |
+
)
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
user_config = {
|
| 98 |
+
"datasets": [
|
| 99 |
+
{
|
| 100 |
+
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
| 101 |
+
args.train_data_dir,
|
| 102 |
+
args.conditioning_data_dir,
|
| 103 |
+
args.caption_extension,
|
| 104 |
+
)
|
| 105 |
+
}
|
| 106 |
+
]
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
| 110 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 111 |
+
|
| 112 |
+
current_epoch = Value("i", 0)
|
| 113 |
+
current_step = Value("i", 0)
|
| 114 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 115 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 116 |
+
|
| 117 |
+
train_dataset_group.verify_bucket_reso_steps(32)
|
| 118 |
+
|
| 119 |
+
if args.debug_dataset:
|
| 120 |
+
train_util.debug_dataset(train_dataset_group)
|
| 121 |
+
return
|
| 122 |
+
if len(train_dataset_group) == 0:
|
| 123 |
+
logger.error(
|
| 124 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
| 125 |
+
)
|
| 126 |
+
return
|
| 127 |
+
|
| 128 |
+
if cache_latents:
|
| 129 |
+
assert (
|
| 130 |
+
train_dataset_group.is_latent_cacheable()
|
| 131 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 132 |
+
else:
|
| 133 |
+
logger.warning(
|
| 134 |
+
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
if args.cache_text_encoder_outputs:
|
| 138 |
+
assert (
|
| 139 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
| 140 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
| 141 |
+
|
| 142 |
+
# acceleratorを準備する
|
| 143 |
+
logger.info("prepare accelerator")
|
| 144 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 145 |
+
is_main_process = accelerator.is_main_process
|
| 146 |
+
|
| 147 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 148 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 149 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 150 |
+
|
| 151 |
+
# モデルを読み込む
|
| 152 |
+
(
|
| 153 |
+
load_stable_diffusion_format,
|
| 154 |
+
text_encoder1,
|
| 155 |
+
text_encoder2,
|
| 156 |
+
vae,
|
| 157 |
+
unet,
|
| 158 |
+
logit_scale,
|
| 159 |
+
ckpt_info,
|
| 160 |
+
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
| 161 |
+
|
| 162 |
+
# 学習を準備する
|
| 163 |
+
if cache_latents:
|
| 164 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 165 |
+
vae.requires_grad_(False)
|
| 166 |
+
vae.eval()
|
| 167 |
+
with torch.no_grad():
|
| 168 |
+
train_dataset_group.cache_latents(
|
| 169 |
+
vae,
|
| 170 |
+
args.vae_batch_size,
|
| 171 |
+
args.cache_latents_to_disk,
|
| 172 |
+
accelerator.is_main_process,
|
| 173 |
+
)
|
| 174 |
+
vae.to("cpu")
|
| 175 |
+
clean_memory_on_device(accelerator.device)
|
| 176 |
+
|
| 177 |
+
accelerator.wait_for_everyone()
|
| 178 |
+
|
| 179 |
+
# TextEncoderの出力をキャッシュする
|
| 180 |
+
if args.cache_text_encoder_outputs:
|
| 181 |
+
# Text Encodes are eval and no grad
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
train_dataset_group.cache_text_encoder_outputs(
|
| 184 |
+
(tokenizer1, tokenizer2),
|
| 185 |
+
(text_encoder1, text_encoder2),
|
| 186 |
+
accelerator.device,
|
| 187 |
+
None,
|
| 188 |
+
args.cache_text_encoder_outputs_to_disk,
|
| 189 |
+
accelerator.is_main_process,
|
| 190 |
+
)
|
| 191 |
+
accelerator.wait_for_everyone()
|
| 192 |
+
|
| 193 |
+
# prepare ControlNet-LLLite
|
| 194 |
+
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
|
| 195 |
+
|
| 196 |
+
if args.network_weights is not None:
|
| 197 |
+
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
| 198 |
+
with accelerate.init_empty_weights():
|
| 199 |
+
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
| 200 |
+
unet_lllite.to(accelerator.device, dtype=weight_dtype)
|
| 201 |
+
|
| 202 |
+
unet_sd = unet.state_dict()
|
| 203 |
+
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
|
| 204 |
+
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
|
| 205 |
+
else:
|
| 206 |
+
# cosumes large memory, so send to GPU before creating the LLLite model
|
| 207 |
+
accelerator.print("sending U-Net to GPU")
|
| 208 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 209 |
+
unet_sd = unet.state_dict()
|
| 210 |
+
|
| 211 |
+
# init LLLite weights
|
| 212 |
+
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
|
| 213 |
+
|
| 214 |
+
if args.lowram:
|
| 215 |
+
with accelerate.init_on_device(accelerator.device):
|
| 216 |
+
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
| 217 |
+
else:
|
| 218 |
+
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
|
| 219 |
+
unet_lllite.to(weight_dtype)
|
| 220 |
+
|
| 221 |
+
info = unet_lllite.load_lllite_weights(None, unet_sd)
|
| 222 |
+
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
|
| 223 |
+
del unet_sd, unet
|
| 224 |
+
|
| 225 |
+
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
|
| 226 |
+
del unet_lllite
|
| 227 |
+
|
| 228 |
+
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
|
| 229 |
+
|
| 230 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 231 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 232 |
+
|
| 233 |
+
if args.gradient_checkpointing:
|
| 234 |
+
unet.enable_gradient_checkpointing()
|
| 235 |
+
|
| 236 |
+
# 学習に必要なクラスを準備する
|
| 237 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 238 |
+
|
| 239 |
+
trainable_params = list(unet.prepare_params())
|
| 240 |
+
logger.info(f"trainable params count: {len(trainable_params)}")
|
| 241 |
+
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
| 242 |
+
|
| 243 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 244 |
+
|
| 245 |
+
# dataloaderを準備する
|
| 246 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 247 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 248 |
+
|
| 249 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 250 |
+
train_dataset_group,
|
| 251 |
+
batch_size=1,
|
| 252 |
+
shuffle=True,
|
| 253 |
+
collate_fn=collator,
|
| 254 |
+
num_workers=n_workers,
|
| 255 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# 学習ステップ数を計算する
|
| 259 |
+
if args.max_train_epochs is not None:
|
| 260 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 261 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 262 |
+
)
|
| 263 |
+
accelerator.print(
|
| 264 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# データセット側にも学習ステップを送信
|
| 268 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 269 |
+
|
| 270 |
+
# lr schedulerを用意する
|
| 271 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 272 |
+
|
| 273 |
+
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
| 274 |
+
# if args.full_fp16:
|
| 275 |
+
# assert (
|
| 276 |
+
# args.mixed_precision == "fp16"
|
| 277 |
+
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 278 |
+
# accelerator.print("enable full fp16 training.")
|
| 279 |
+
# unet.to(weight_dtype)
|
| 280 |
+
# elif args.full_bf16:
|
| 281 |
+
# assert (
|
| 282 |
+
# args.mixed_precision == "bf16"
|
| 283 |
+
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
| 284 |
+
# accelerator.print("enable full bf16 training.")
|
| 285 |
+
# unet.to(weight_dtype)
|
| 286 |
+
|
| 287 |
+
unet.to(weight_dtype)
|
| 288 |
+
|
| 289 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 290 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
| 291 |
+
|
| 292 |
+
if isinstance(unet, DDP):
|
| 293 |
+
unet._set_static_graph() # avoid error for multiple use of the parameter
|
| 294 |
+
|
| 295 |
+
if args.gradient_checkpointing:
|
| 296 |
+
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
| 297 |
+
else:
|
| 298 |
+
unet.eval()
|
| 299 |
+
|
| 300 |
+
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
| 301 |
+
if args.cache_text_encoder_outputs:
|
| 302 |
+
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
| 303 |
+
text_encoder1.to("cpu", dtype=torch.float32)
|
| 304 |
+
text_encoder2.to("cpu", dtype=torch.float32)
|
| 305 |
+
clean_memory_on_device(accelerator.device)
|
| 306 |
+
else:
|
| 307 |
+
# make sure Text Encoders are on GPU
|
| 308 |
+
text_encoder1.to(accelerator.device)
|
| 309 |
+
text_encoder2.to(accelerator.device)
|
| 310 |
+
|
| 311 |
+
if not cache_latents:
|
| 312 |
+
vae.requires_grad_(False)
|
| 313 |
+
vae.eval()
|
| 314 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 315 |
+
|
| 316 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 317 |
+
if args.full_fp16:
|
| 318 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 319 |
+
|
| 320 |
+
# resumeする
|
| 321 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 322 |
+
|
| 323 |
+
# epoch数を計算する
|
| 324 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 325 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 326 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 327 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 328 |
+
|
| 329 |
+
# 学習する
|
| 330 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
| 331 |
+
accelerator.print("running training / 学習開始")
|
| 332 |
+
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 333 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 334 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 335 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 336 |
+
accelerator.print(
|
| 337 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
| 338 |
+
)
|
| 339 |
+
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
| 340 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 341 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 342 |
+
|
| 343 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 344 |
+
global_step = 0
|
| 345 |
+
|
| 346 |
+
noise_scheduler = DDPMScheduler(
|
| 347 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 348 |
+
)
|
| 349 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 350 |
+
if args.zero_terminal_snr:
|
| 351 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 352 |
+
|
| 353 |
+
if accelerator.is_main_process:
|
| 354 |
+
init_kwargs = {}
|
| 355 |
+
if args.wandb_run_name:
|
| 356 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 357 |
+
if args.log_tracker_config is not None:
|
| 358 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 359 |
+
accelerator.init_trackers(
|
| 360 |
+
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
loss_recorder = train_util.LossRecorder()
|
| 364 |
+
del train_dataset_group
|
| 365 |
+
|
| 366 |
+
# function for saving/removing
|
| 367 |
+
def save_model(
|
| 368 |
+
ckpt_name,
|
| 369 |
+
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
|
| 370 |
+
steps,
|
| 371 |
+
epoch_no,
|
| 372 |
+
force_sync_upload=False,
|
| 373 |
+
):
|
| 374 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 375 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
| 376 |
+
|
| 377 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
| 378 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
| 379 |
+
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
| 380 |
+
|
| 381 |
+
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
|
| 382 |
+
if args.huggingface_repo_id is not None:
|
| 383 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
| 384 |
+
|
| 385 |
+
def remove_model(old_ckpt_name):
|
| 386 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
| 387 |
+
if os.path.exists(old_ckpt_file):
|
| 388 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
| 389 |
+
os.remove(old_ckpt_file)
|
| 390 |
+
|
| 391 |
+
# training loop
|
| 392 |
+
for epoch in range(num_train_epochs):
|
| 393 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 394 |
+
current_epoch.value = epoch + 1
|
| 395 |
+
|
| 396 |
+
for step, batch in enumerate(train_dataloader):
|
| 397 |
+
current_step.value = global_step
|
| 398 |
+
with accelerator.accumulate(unet):
|
| 399 |
+
with torch.no_grad():
|
| 400 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 401 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 402 |
+
else:
|
| 403 |
+
# latentに変換
|
| 404 |
+
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
| 405 |
+
|
| 406 |
+
# NaNが含まれていれば警告を表示し0に置き換える
|
| 407 |
+
if torch.any(torch.isnan(latents)):
|
| 408 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
| 409 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
| 410 |
+
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
| 411 |
+
|
| 412 |
+
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
| 413 |
+
input_ids1 = batch["input_ids"]
|
| 414 |
+
input_ids2 = batch["input_ids2"]
|
| 415 |
+
with torch.no_grad():
|
| 416 |
+
# Get the text embedding for conditioning
|
| 417 |
+
input_ids1 = input_ids1.to(accelerator.device)
|
| 418 |
+
input_ids2 = input_ids2.to(accelerator.device)
|
| 419 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
| 420 |
+
args.max_token_length,
|
| 421 |
+
input_ids1,
|
| 422 |
+
input_ids2,
|
| 423 |
+
tokenizer1,
|
| 424 |
+
tokenizer2,
|
| 425 |
+
text_encoder1,
|
| 426 |
+
text_encoder2,
|
| 427 |
+
None if not args.full_fp16 else weight_dtype,
|
| 428 |
+
)
|
| 429 |
+
else:
|
| 430 |
+
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
| 431 |
+
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
| 432 |
+
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
| 433 |
+
|
| 434 |
+
# get size embeddings
|
| 435 |
+
orig_size = batch["original_sizes_hw"]
|
| 436 |
+
crop_size = batch["crop_top_lefts"]
|
| 437 |
+
target_size = batch["target_sizes_hw"]
|
| 438 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
| 439 |
+
|
| 440 |
+
# concat embeddings
|
| 441 |
+
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
| 442 |
+
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
| 443 |
+
|
| 444 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 445 |
+
# with noise offset and/or multires noise if specified
|
| 446 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
| 447 |
+
args, noise_scheduler, latents
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
| 451 |
+
|
| 452 |
+
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
| 453 |
+
|
| 454 |
+
with accelerator.autocast():
|
| 455 |
+
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
| 456 |
+
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
| 457 |
+
|
| 458 |
+
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
| 459 |
+
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
|
| 460 |
+
|
| 461 |
+
if args.v_parameterization:
|
| 462 |
+
# v-parameterization training
|
| 463 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 464 |
+
else:
|
| 465 |
+
target = noise
|
| 466 |
+
|
| 467 |
+
loss = train_util.conditional_loss(
|
| 468 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
| 469 |
+
)
|
| 470 |
+
loss = loss.mean([1, 2, 3])
|
| 471 |
+
|
| 472 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 473 |
+
loss = loss * loss_weights
|
| 474 |
+
|
| 475 |
+
if args.min_snr_gamma:
|
| 476 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 477 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 478 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 479 |
+
if args.v_pred_like_loss:
|
| 480 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
| 481 |
+
if args.debiased_estimation_loss:
|
| 482 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 483 |
+
|
| 484 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 485 |
+
|
| 486 |
+
accelerator.backward(loss)
|
| 487 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 488 |
+
params_to_clip = accelerator.unwrap_model(unet).get_trainable_params()
|
| 489 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 490 |
+
|
| 491 |
+
optimizer.step()
|
| 492 |
+
lr_scheduler.step()
|
| 493 |
+
optimizer.zero_grad(set_to_none=True)
|
| 494 |
+
|
| 495 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 496 |
+
if accelerator.sync_gradients:
|
| 497 |
+
progress_bar.update(1)
|
| 498 |
+
global_step += 1
|
| 499 |
+
|
| 500 |
+
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 501 |
+
|
| 502 |
+
# 指定ステップごとにモデルを保存
|
| 503 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 504 |
+
accelerator.wait_for_everyone()
|
| 505 |
+
if accelerator.is_main_process:
|
| 506 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
| 507 |
+
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
|
| 508 |
+
|
| 509 |
+
if args.save_state:
|
| 510 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
| 511 |
+
|
| 512 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
| 513 |
+
if remove_step_no is not None:
|
| 514 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
| 515 |
+
remove_model(remove_ckpt_name)
|
| 516 |
+
|
| 517 |
+
current_loss = loss.detach().item()
|
| 518 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 519 |
+
avr_loss: float = loss_recorder.moving_average
|
| 520 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 521 |
+
progress_bar.set_postfix(**logs)
|
| 522 |
+
|
| 523 |
+
if args.logging_dir is not None:
|
| 524 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
| 525 |
+
accelerator.log(logs, step=global_step)
|
| 526 |
+
|
| 527 |
+
if global_step >= args.max_train_steps:
|
| 528 |
+
break
|
| 529 |
+
|
| 530 |
+
if args.logging_dir is not None:
|
| 531 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 532 |
+
accelerator.log(logs, step=epoch + 1)
|
| 533 |
+
|
| 534 |
+
accelerator.wait_for_everyone()
|
| 535 |
+
|
| 536 |
+
# 指定エポックごとにモデルを保存
|
| 537 |
+
if args.save_every_n_epochs is not None:
|
| 538 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
| 539 |
+
if is_main_process and saving:
|
| 540 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
| 541 |
+
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
|
| 542 |
+
|
| 543 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
| 544 |
+
if remove_epoch_no is not None:
|
| 545 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
| 546 |
+
remove_model(remove_ckpt_name)
|
| 547 |
+
|
| 548 |
+
if args.save_state:
|
| 549 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
| 550 |
+
|
| 551 |
+
# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 552 |
+
|
| 553 |
+
# end of epoch
|
| 554 |
+
|
| 555 |
+
if is_main_process:
|
| 556 |
+
unet = accelerator.unwrap_model(unet)
|
| 557 |
+
|
| 558 |
+
accelerator.end_training()
|
| 559 |
+
|
| 560 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 561 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 562 |
+
|
| 563 |
+
if is_main_process:
|
| 564 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
| 565 |
+
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
| 566 |
+
|
| 567 |
+
logger.info("model saved.")
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 571 |
+
parser = argparse.ArgumentParser()
|
| 572 |
+
|
| 573 |
+
add_logging_arguments(parser)
|
| 574 |
+
train_util.add_sd_models_arguments(parser)
|
| 575 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
| 576 |
+
train_util.add_training_arguments(parser, False)
|
| 577 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 578 |
+
train_util.add_optimizer_arguments(parser)
|
| 579 |
+
config_util.add_config_arguments(parser)
|
| 580 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 581 |
+
sdxl_train_util.add_sdxl_training_arguments(parser)
|
| 582 |
+
|
| 583 |
+
parser.add_argument(
|
| 584 |
+
"--save_model_as",
|
| 585 |
+
type=str,
|
| 586 |
+
default="safetensors",
|
| 587 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
| 588 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
| 589 |
+
)
|
| 590 |
+
parser.add_argument(
|
| 591 |
+
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
|
| 592 |
+
)
|
| 593 |
+
parser.add_argument(
|
| 594 |
+
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
| 595 |
+
)
|
| 596 |
+
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
|
| 597 |
+
parser.add_argument(
|
| 598 |
+
"--network_dropout",
|
| 599 |
+
type=float,
|
| 600 |
+
default=None,
|
| 601 |
+
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
| 602 |
+
)
|
| 603 |
+
parser.add_argument(
|
| 604 |
+
"--conditioning_data_dir",
|
| 605 |
+
type=str,
|
| 606 |
+
default=None,
|
| 607 |
+
help="conditioning data directory / 条件付けデータのディレクトリ",
|
| 608 |
+
)
|
| 609 |
+
parser.add_argument(
|
| 610 |
+
"--no_half_vae",
|
| 611 |
+
action="store_true",
|
| 612 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 613 |
+
)
|
| 614 |
+
return parser
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
if __name__ == "__main__":
|
| 618 |
+
# sdxl_original_unet.USE_REENTRANT = False
|
| 619 |
+
|
| 620 |
+
parser = setup_parser()
|
| 621 |
+
|
| 622 |
+
args = parser.parse_args()
|
| 623 |
+
train_util.verify_command_line_training_args(args)
|
| 624 |
+
args = train_util.read_config_from_file(args, parser)
|
| 625 |
+
|
| 626 |
+
train(args)
|
sdxl_train_control_net_lllite_old.py
ADDED
|
@@ -0,0 +1,586 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from multiprocessing import Value
|
| 8 |
+
from types import SimpleNamespace
|
| 9 |
+
import toml
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 15 |
+
init_ipex()
|
| 16 |
+
|
| 17 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 18 |
+
from accelerate.utils import set_seed
|
| 19 |
+
from diffusers import DDPMScheduler, ControlNetModel
|
| 20 |
+
from safetensors.torch import load_file
|
| 21 |
+
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
|
| 22 |
+
|
| 23 |
+
import library.model_util as model_util
|
| 24 |
+
import library.train_util as train_util
|
| 25 |
+
import library.config_util as config_util
|
| 26 |
+
from library.config_util import (
|
| 27 |
+
ConfigSanitizer,
|
| 28 |
+
BlueprintGenerator,
|
| 29 |
+
)
|
| 30 |
+
import library.huggingface_util as huggingface_util
|
| 31 |
+
import library.custom_train_functions as custom_train_functions
|
| 32 |
+
from library.custom_train_functions import (
|
| 33 |
+
add_v_prediction_like_loss,
|
| 34 |
+
apply_snr_weight,
|
| 35 |
+
prepare_scheduler_for_custom_training,
|
| 36 |
+
pyramid_noise_like,
|
| 37 |
+
apply_noise_offset,
|
| 38 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 39 |
+
apply_debiased_estimation,
|
| 40 |
+
)
|
| 41 |
+
import networks.control_net_lllite as control_net_lllite
|
| 42 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 43 |
+
|
| 44 |
+
setup_logging()
|
| 45 |
+
import logging
|
| 46 |
+
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# TODO 他のスクリプトと共通化する
|
| 51 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
| 52 |
+
logs = {
|
| 53 |
+
"loss/current": current_loss,
|
| 54 |
+
"loss/average": avr_loss,
|
| 55 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
| 59 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
| 60 |
+
|
| 61 |
+
return logs
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def train(args):
|
| 65 |
+
train_util.verify_training_args(args)
|
| 66 |
+
train_util.prepare_dataset_args(args, True)
|
| 67 |
+
sdxl_train_util.verify_sdxl_training_args(args)
|
| 68 |
+
setup_logging(args, reset=True)
|
| 69 |
+
|
| 70 |
+
cache_latents = args.cache_latents
|
| 71 |
+
use_user_config = args.dataset_config is not None
|
| 72 |
+
|
| 73 |
+
if args.seed is None:
|
| 74 |
+
args.seed = random.randint(0, 2**32)
|
| 75 |
+
set_seed(args.seed)
|
| 76 |
+
|
| 77 |
+
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
| 78 |
+
|
| 79 |
+
# データセットを準備する
|
| 80 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
| 81 |
+
if use_user_config:
|
| 82 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 83 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 84 |
+
ignored = ["train_data_dir", "conditioning_data_dir"]
|
| 85 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 86 |
+
logger.warning(
|
| 87 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 88 |
+
", ".join(ignored)
|
| 89 |
+
)
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
user_config = {
|
| 93 |
+
"datasets": [
|
| 94 |
+
{
|
| 95 |
+
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
| 96 |
+
args.train_data_dir,
|
| 97 |
+
args.conditioning_data_dir,
|
| 98 |
+
args.caption_extension,
|
| 99 |
+
)
|
| 100 |
+
}
|
| 101 |
+
]
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
| 105 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 106 |
+
|
| 107 |
+
current_epoch = Value("i", 0)
|
| 108 |
+
current_step = Value("i", 0)
|
| 109 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 110 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 111 |
+
|
| 112 |
+
train_dataset_group.verify_bucket_reso_steps(32)
|
| 113 |
+
|
| 114 |
+
if args.debug_dataset:
|
| 115 |
+
train_util.debug_dataset(train_dataset_group)
|
| 116 |
+
return
|
| 117 |
+
if len(train_dataset_group) == 0:
|
| 118 |
+
logger.error(
|
| 119 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
| 120 |
+
)
|
| 121 |
+
return
|
| 122 |
+
|
| 123 |
+
if cache_latents:
|
| 124 |
+
assert (
|
| 125 |
+
train_dataset_group.is_latent_cacheable()
|
| 126 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 127 |
+
else:
|
| 128 |
+
logger.warning(
|
| 129 |
+
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
if args.cache_text_encoder_outputs:
|
| 133 |
+
assert (
|
| 134 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
| 135 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
| 136 |
+
|
| 137 |
+
# acceleratorを準備する
|
| 138 |
+
logger.info("prepare accelerator")
|
| 139 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 140 |
+
is_main_process = accelerator.is_main_process
|
| 141 |
+
|
| 142 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 143 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 144 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 145 |
+
|
| 146 |
+
# モデルを読み込む
|
| 147 |
+
(
|
| 148 |
+
load_stable_diffusion_format,
|
| 149 |
+
text_encoder1,
|
| 150 |
+
text_encoder2,
|
| 151 |
+
vae,
|
| 152 |
+
unet,
|
| 153 |
+
logit_scale,
|
| 154 |
+
ckpt_info,
|
| 155 |
+
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
| 156 |
+
|
| 157 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 158 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 159 |
+
|
| 160 |
+
# 学習を準備する
|
| 161 |
+
if cache_latents:
|
| 162 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 163 |
+
vae.requires_grad_(False)
|
| 164 |
+
vae.eval()
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
train_dataset_group.cache_latents(
|
| 167 |
+
vae,
|
| 168 |
+
args.vae_batch_size,
|
| 169 |
+
args.cache_latents_to_disk,
|
| 170 |
+
accelerator.is_main_process,
|
| 171 |
+
)
|
| 172 |
+
vae.to("cpu")
|
| 173 |
+
clean_memory_on_device(accelerator.device)
|
| 174 |
+
|
| 175 |
+
accelerator.wait_for_everyone()
|
| 176 |
+
|
| 177 |
+
# TextEncoderの出力をキャッシュする
|
| 178 |
+
if args.cache_text_encoder_outputs:
|
| 179 |
+
# Text Encodes are eval and no grad
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
train_dataset_group.cache_text_encoder_outputs(
|
| 182 |
+
(tokenizer1, tokenizer2),
|
| 183 |
+
(text_encoder1, text_encoder2),
|
| 184 |
+
accelerator.device,
|
| 185 |
+
None,
|
| 186 |
+
args.cache_text_encoder_outputs_to_disk,
|
| 187 |
+
accelerator.is_main_process,
|
| 188 |
+
)
|
| 189 |
+
accelerator.wait_for_everyone()
|
| 190 |
+
|
| 191 |
+
# prepare ControlNet
|
| 192 |
+
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
|
| 193 |
+
network.apply_to()
|
| 194 |
+
|
| 195 |
+
if args.network_weights is not None:
|
| 196 |
+
info = network.load_weights(args.network_weights)
|
| 197 |
+
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
|
| 198 |
+
|
| 199 |
+
if args.gradient_checkpointing:
|
| 200 |
+
unet.enable_gradient_checkpointing()
|
| 201 |
+
network.enable_gradient_checkpointing() # may have no effect
|
| 202 |
+
|
| 203 |
+
# 学習に必要なクラスを準備する
|
| 204 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 205 |
+
|
| 206 |
+
trainable_params = list(network.prepare_optimizer_params())
|
| 207 |
+
logger.info(f"trainable params count: {len(trainable_params)}")
|
| 208 |
+
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
| 209 |
+
|
| 210 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 211 |
+
|
| 212 |
+
# dataloaderを準備する
|
| 213 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 214 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 215 |
+
|
| 216 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 217 |
+
train_dataset_group,
|
| 218 |
+
batch_size=1,
|
| 219 |
+
shuffle=True,
|
| 220 |
+
collate_fn=collator,
|
| 221 |
+
num_workers=n_workers,
|
| 222 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# 学習ステップ数を計算する
|
| 226 |
+
if args.max_train_epochs is not None:
|
| 227 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 228 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 229 |
+
)
|
| 230 |
+
accelerator.print(
|
| 231 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# データセット側にも学習ステップを送信
|
| 235 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 236 |
+
|
| 237 |
+
# lr schedulerを用意する
|
| 238 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 239 |
+
|
| 240 |
+
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
| 241 |
+
if args.full_fp16:
|
| 242 |
+
assert (
|
| 243 |
+
args.mixed_precision == "fp16"
|
| 244 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 245 |
+
accelerator.print("enable full fp16 training.")
|
| 246 |
+
unet.to(weight_dtype)
|
| 247 |
+
network.to(weight_dtype)
|
| 248 |
+
elif args.full_bf16:
|
| 249 |
+
assert (
|
| 250 |
+
args.mixed_precision == "bf16"
|
| 251 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
| 252 |
+
accelerator.print("enable full bf16 training.")
|
| 253 |
+
unet.to(weight_dtype)
|
| 254 |
+
network.to(weight_dtype)
|
| 255 |
+
|
| 256 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 257 |
+
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 258 |
+
unet, network, optimizer, train_dataloader, lr_scheduler
|
| 259 |
+
)
|
| 260 |
+
network: control_net_lllite.ControlNetLLLite
|
| 261 |
+
|
| 262 |
+
if args.gradient_checkpointing:
|
| 263 |
+
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
| 264 |
+
else:
|
| 265 |
+
unet.eval()
|
| 266 |
+
|
| 267 |
+
network.prepare_grad_etc()
|
| 268 |
+
|
| 269 |
+
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
| 270 |
+
if args.cache_text_encoder_outputs:
|
| 271 |
+
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
| 272 |
+
text_encoder1.to("cpu", dtype=torch.float32)
|
| 273 |
+
text_encoder2.to("cpu", dtype=torch.float32)
|
| 274 |
+
clean_memory_on_device(accelerator.device)
|
| 275 |
+
else:
|
| 276 |
+
# make sure Text Encoders are on GPU
|
| 277 |
+
text_encoder1.to(accelerator.device)
|
| 278 |
+
text_encoder2.to(accelerator.device)
|
| 279 |
+
|
| 280 |
+
if not cache_latents:
|
| 281 |
+
vae.requires_grad_(False)
|
| 282 |
+
vae.eval()
|
| 283 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 284 |
+
|
| 285 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 286 |
+
if args.full_fp16:
|
| 287 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 288 |
+
|
| 289 |
+
# resumeする
|
| 290 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 291 |
+
|
| 292 |
+
# epoch数を計算する
|
| 293 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 294 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 295 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 296 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 297 |
+
|
| 298 |
+
# 学習する
|
| 299 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
| 300 |
+
accelerator.print("running training / 学習開始")
|
| 301 |
+
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 302 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 303 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 304 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 305 |
+
accelerator.print(
|
| 306 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
| 307 |
+
)
|
| 308 |
+
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
| 309 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 310 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 311 |
+
|
| 312 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 313 |
+
global_step = 0
|
| 314 |
+
|
| 315 |
+
noise_scheduler = DDPMScheduler(
|
| 316 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 317 |
+
)
|
| 318 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 319 |
+
if args.zero_terminal_snr:
|
| 320 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 321 |
+
|
| 322 |
+
if accelerator.is_main_process:
|
| 323 |
+
init_kwargs = {}
|
| 324 |
+
if args.log_tracker_config is not None:
|
| 325 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 326 |
+
accelerator.init_trackers(
|
| 327 |
+
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
loss_recorder = train_util.LossRecorder()
|
| 331 |
+
del train_dataset_group
|
| 332 |
+
|
| 333 |
+
# function for saving/removing
|
| 334 |
+
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
| 335 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 336 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
| 337 |
+
|
| 338 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
| 339 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
|
| 340 |
+
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
|
| 341 |
+
|
| 342 |
+
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
|
| 343 |
+
if args.huggingface_repo_id is not None:
|
| 344 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
| 345 |
+
|
| 346 |
+
def remove_model(old_ckpt_name):
|
| 347 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
| 348 |
+
if os.path.exists(old_ckpt_file):
|
| 349 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
| 350 |
+
os.remove(old_ckpt_file)
|
| 351 |
+
|
| 352 |
+
# training loop
|
| 353 |
+
for epoch in range(num_train_epochs):
|
| 354 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 355 |
+
current_epoch.value = epoch + 1
|
| 356 |
+
|
| 357 |
+
network.on_epoch_start() # train()
|
| 358 |
+
|
| 359 |
+
for step, batch in enumerate(train_dataloader):
|
| 360 |
+
current_step.value = global_step
|
| 361 |
+
with accelerator.accumulate(network):
|
| 362 |
+
with torch.no_grad():
|
| 363 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 364 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 365 |
+
else:
|
| 366 |
+
# latentに変換
|
| 367 |
+
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
| 368 |
+
|
| 369 |
+
# NaNが含まれていれば警告を表示し0に置き換える
|
| 370 |
+
if torch.any(torch.isnan(latents)):
|
| 371 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
| 372 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
| 373 |
+
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
| 374 |
+
|
| 375 |
+
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
| 376 |
+
input_ids1 = batch["input_ids"]
|
| 377 |
+
input_ids2 = batch["input_ids2"]
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
# Get the text embedding for conditioning
|
| 380 |
+
input_ids1 = input_ids1.to(accelerator.device)
|
| 381 |
+
input_ids2 = input_ids2.to(accelerator.device)
|
| 382 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
| 383 |
+
args.max_token_length,
|
| 384 |
+
input_ids1,
|
| 385 |
+
input_ids2,
|
| 386 |
+
tokenizer1,
|
| 387 |
+
tokenizer2,
|
| 388 |
+
text_encoder1,
|
| 389 |
+
text_encoder2,
|
| 390 |
+
None if not args.full_fp16 else weight_dtype,
|
| 391 |
+
)
|
| 392 |
+
else:
|
| 393 |
+
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
| 394 |
+
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
| 395 |
+
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
| 396 |
+
|
| 397 |
+
# get size embeddings
|
| 398 |
+
orig_size = batch["original_sizes_hw"]
|
| 399 |
+
crop_size = batch["crop_top_lefts"]
|
| 400 |
+
target_size = batch["target_sizes_hw"]
|
| 401 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
| 402 |
+
|
| 403 |
+
# concat embeddings
|
| 404 |
+
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
| 405 |
+
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
| 406 |
+
|
| 407 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 408 |
+
# with noise offset and/or multires noise if specified
|
| 409 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
| 410 |
+
|
| 411 |
+
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
| 412 |
+
|
| 413 |
+
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
| 414 |
+
|
| 415 |
+
with accelerator.autocast():
|
| 416 |
+
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
|
| 417 |
+
# 内部でcond_embに変換される / it will be converted to cond_emb inside
|
| 418 |
+
network.set_cond_image(controlnet_image)
|
| 419 |
+
|
| 420 |
+
# それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
|
| 421 |
+
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
| 422 |
+
|
| 423 |
+
if args.v_parameterization:
|
| 424 |
+
# v-parameterization training
|
| 425 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 426 |
+
else:
|
| 427 |
+
target = noise
|
| 428 |
+
|
| 429 |
+
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
| 430 |
+
loss = loss.mean([1, 2, 3])
|
| 431 |
+
|
| 432 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 433 |
+
loss = loss * loss_weights
|
| 434 |
+
|
| 435 |
+
if args.min_snr_gamma:
|
| 436 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 437 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 438 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 439 |
+
if args.v_pred_like_loss:
|
| 440 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
| 441 |
+
if args.debiased_estimation_loss:
|
| 442 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 443 |
+
|
| 444 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 445 |
+
|
| 446 |
+
accelerator.backward(loss)
|
| 447 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 448 |
+
params_to_clip = network.get_trainable_params()
|
| 449 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 450 |
+
|
| 451 |
+
optimizer.step()
|
| 452 |
+
lr_scheduler.step()
|
| 453 |
+
optimizer.zero_grad(set_to_none=True)
|
| 454 |
+
|
| 455 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 456 |
+
if accelerator.sync_gradients:
|
| 457 |
+
progress_bar.update(1)
|
| 458 |
+
global_step += 1
|
| 459 |
+
|
| 460 |
+
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 461 |
+
|
| 462 |
+
# 指定ステップごとにモデルを保存
|
| 463 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 464 |
+
accelerator.wait_for_everyone()
|
| 465 |
+
if accelerator.is_main_process:
|
| 466 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
| 467 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
| 468 |
+
|
| 469 |
+
if args.save_state:
|
| 470 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
| 471 |
+
|
| 472 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
| 473 |
+
if remove_step_no is not None:
|
| 474 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
| 475 |
+
remove_model(remove_ckpt_name)
|
| 476 |
+
|
| 477 |
+
current_loss = loss.detach().item()
|
| 478 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 479 |
+
avr_loss: float = loss_recorder.moving_average
|
| 480 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 481 |
+
progress_bar.set_postfix(**logs)
|
| 482 |
+
|
| 483 |
+
if args.logging_dir is not None:
|
| 484 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
| 485 |
+
accelerator.log(logs, step=global_step)
|
| 486 |
+
|
| 487 |
+
if global_step >= args.max_train_steps:
|
| 488 |
+
break
|
| 489 |
+
|
| 490 |
+
if args.logging_dir is not None:
|
| 491 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 492 |
+
accelerator.log(logs, step=epoch + 1)
|
| 493 |
+
|
| 494 |
+
accelerator.wait_for_everyone()
|
| 495 |
+
|
| 496 |
+
# 指定エポックごとにモデルを保存
|
| 497 |
+
if args.save_every_n_epochs is not None:
|
| 498 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
| 499 |
+
if is_main_process and saving:
|
| 500 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
| 501 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
| 502 |
+
|
| 503 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
| 504 |
+
if remove_epoch_no is not None:
|
| 505 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
| 506 |
+
remove_model(remove_ckpt_name)
|
| 507 |
+
|
| 508 |
+
if args.save_state:
|
| 509 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
| 510 |
+
|
| 511 |
+
# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 512 |
+
|
| 513 |
+
# end of epoch
|
| 514 |
+
|
| 515 |
+
if is_main_process:
|
| 516 |
+
network = accelerator.unwrap_model(network)
|
| 517 |
+
|
| 518 |
+
accelerator.end_training()
|
| 519 |
+
|
| 520 |
+
if is_main_process and args.save_state:
|
| 521 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 522 |
+
|
| 523 |
+
if is_main_process:
|
| 524 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
| 525 |
+
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
| 526 |
+
|
| 527 |
+
logger.info("model saved.")
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 531 |
+
parser = argparse.ArgumentParser()
|
| 532 |
+
|
| 533 |
+
add_logging_arguments(parser)
|
| 534 |
+
train_util.add_sd_models_arguments(parser)
|
| 535 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
| 536 |
+
train_util.add_training_arguments(parser, False)
|
| 537 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 538 |
+
train_util.add_optimizer_arguments(parser)
|
| 539 |
+
config_util.add_config_arguments(parser)
|
| 540 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 541 |
+
sdxl_train_util.add_sdxl_training_arguments(parser)
|
| 542 |
+
|
| 543 |
+
parser.add_argument(
|
| 544 |
+
"--save_model_as",
|
| 545 |
+
type=str,
|
| 546 |
+
default="safetensors",
|
| 547 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
| 548 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
| 549 |
+
)
|
| 550 |
+
parser.add_argument(
|
| 551 |
+
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
|
| 552 |
+
)
|
| 553 |
+
parser.add_argument(
|
| 554 |
+
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
| 555 |
+
)
|
| 556 |
+
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
|
| 557 |
+
parser.add_argument(
|
| 558 |
+
"--network_dropout",
|
| 559 |
+
type=float,
|
| 560 |
+
default=None,
|
| 561 |
+
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
| 562 |
+
)
|
| 563 |
+
parser.add_argument(
|
| 564 |
+
"--conditioning_data_dir",
|
| 565 |
+
type=str,
|
| 566 |
+
default=None,
|
| 567 |
+
help="conditioning data directory / 条件付けデータのディレクトリ",
|
| 568 |
+
)
|
| 569 |
+
parser.add_argument(
|
| 570 |
+
"--no_half_vae",
|
| 571 |
+
action="store_true",
|
| 572 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 573 |
+
)
|
| 574 |
+
return parser
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
if __name__ == "__main__":
|
| 578 |
+
# sdxl_original_unet.USE_REENTRANT = False
|
| 579 |
+
|
| 580 |
+
parser = setup_parser()
|
| 581 |
+
|
| 582 |
+
args = parser.parse_args()
|
| 583 |
+
train_util.verify_command_line_training_args(args)
|
| 584 |
+
args = train_util.read_config_from_file(args, parser)
|
| 585 |
+
|
| 586 |
+
train(args)
|
sdxl_train_network.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 5 |
+
init_ipex()
|
| 6 |
+
|
| 7 |
+
from library import sdxl_model_util, sdxl_train_util, train_util
|
| 8 |
+
import train_network
|
| 9 |
+
from library.utils import setup_logging
|
| 10 |
+
setup_logging()
|
| 11 |
+
import logging
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
| 18 |
+
self.is_sdxl = True
|
| 19 |
+
|
| 20 |
+
def assert_extra_args(self, args, train_dataset_group):
|
| 21 |
+
sdxl_train_util.verify_sdxl_training_args(args)
|
| 22 |
+
|
| 23 |
+
if args.cache_text_encoder_outputs:
|
| 24 |
+
assert (
|
| 25 |
+
train_dataset_group.is_text_encoder_output_cacheable()
|
| 26 |
+
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
| 27 |
+
|
| 28 |
+
assert (
|
| 29 |
+
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
| 30 |
+
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
| 31 |
+
|
| 32 |
+
train_dataset_group.verify_bucket_reso_steps(32)
|
| 33 |
+
|
| 34 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
| 35 |
+
(
|
| 36 |
+
load_stable_diffusion_format,
|
| 37 |
+
text_encoder1,
|
| 38 |
+
text_encoder2,
|
| 39 |
+
vae,
|
| 40 |
+
unet,
|
| 41 |
+
logit_scale,
|
| 42 |
+
ckpt_info,
|
| 43 |
+
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
| 44 |
+
|
| 45 |
+
self.load_stable_diffusion_format = load_stable_diffusion_format
|
| 46 |
+
self.logit_scale = logit_scale
|
| 47 |
+
self.ckpt_info = ckpt_info
|
| 48 |
+
|
| 49 |
+
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
| 50 |
+
|
| 51 |
+
def load_tokenizer(self, args):
|
| 52 |
+
tokenizer = sdxl_train_util.load_tokenizers(args)
|
| 53 |
+
return tokenizer
|
| 54 |
+
|
| 55 |
+
def is_text_encoder_outputs_cached(self, args):
|
| 56 |
+
return args.cache_text_encoder_outputs
|
| 57 |
+
|
| 58 |
+
def cache_text_encoder_outputs_if_needed(
|
| 59 |
+
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
| 60 |
+
):
|
| 61 |
+
if args.cache_text_encoder_outputs:
|
| 62 |
+
if not args.lowram:
|
| 63 |
+
# メモリ消費を減らす
|
| 64 |
+
logger.info("move vae and unet to cpu to save memory")
|
| 65 |
+
org_vae_device = vae.device
|
| 66 |
+
org_unet_device = unet.device
|
| 67 |
+
vae.to("cpu")
|
| 68 |
+
unet.to("cpu")
|
| 69 |
+
clean_memory_on_device(accelerator.device)
|
| 70 |
+
|
| 71 |
+
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
| 72 |
+
with accelerator.autocast():
|
| 73 |
+
dataset.cache_text_encoder_outputs(
|
| 74 |
+
tokenizers,
|
| 75 |
+
text_encoders,
|
| 76 |
+
accelerator.device,
|
| 77 |
+
weight_dtype,
|
| 78 |
+
args.cache_text_encoder_outputs_to_disk,
|
| 79 |
+
accelerator.is_main_process,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
| 83 |
+
text_encoders[1].to("cpu", dtype=torch.float32)
|
| 84 |
+
clean_memory_on_device(accelerator.device)
|
| 85 |
+
|
| 86 |
+
if not args.lowram:
|
| 87 |
+
logger.info("move vae and unet back to original device")
|
| 88 |
+
vae.to(org_vae_device)
|
| 89 |
+
unet.to(org_unet_device)
|
| 90 |
+
else:
|
| 91 |
+
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
| 92 |
+
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
| 93 |
+
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
| 94 |
+
|
| 95 |
+
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
| 96 |
+
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
| 97 |
+
input_ids1 = batch["input_ids"]
|
| 98 |
+
input_ids2 = batch["input_ids2"]
|
| 99 |
+
with torch.enable_grad():
|
| 100 |
+
# Get the text embedding for conditioning
|
| 101 |
+
# TODO support weighted captions
|
| 102 |
+
# if args.weighted_captions:
|
| 103 |
+
# encoder_hidden_states = get_weighted_text_embeddings(
|
| 104 |
+
# tokenizer,
|
| 105 |
+
# text_encoder,
|
| 106 |
+
# batch["captions"],
|
| 107 |
+
# accelerator.device,
|
| 108 |
+
# args.max_token_length // 75 if args.max_token_length else 1,
|
| 109 |
+
# clip_skip=args.clip_skip,
|
| 110 |
+
# )
|
| 111 |
+
# else:
|
| 112 |
+
input_ids1 = input_ids1.to(accelerator.device)
|
| 113 |
+
input_ids2 = input_ids2.to(accelerator.device)
|
| 114 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
| 115 |
+
args.max_token_length,
|
| 116 |
+
input_ids1,
|
| 117 |
+
input_ids2,
|
| 118 |
+
tokenizers[0],
|
| 119 |
+
tokenizers[1],
|
| 120 |
+
text_encoders[0],
|
| 121 |
+
text_encoders[1],
|
| 122 |
+
None if not args.full_fp16 else weight_dtype,
|
| 123 |
+
accelerator=accelerator,
|
| 124 |
+
)
|
| 125 |
+
else:
|
| 126 |
+
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
| 127 |
+
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
|
| 128 |
+
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
| 129 |
+
|
| 130 |
+
# # verify that the text encoder outputs are correct
|
| 131 |
+
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
|
| 132 |
+
# args.max_token_length,
|
| 133 |
+
# batch["input_ids"].to(text_encoders[0].device),
|
| 134 |
+
# batch["input_ids2"].to(text_encoders[0].device),
|
| 135 |
+
# tokenizers[0],
|
| 136 |
+
# tokenizers[1],
|
| 137 |
+
# text_encoders[0],
|
| 138 |
+
# text_encoders[1],
|
| 139 |
+
# None if not args.full_fp16 else weight_dtype,
|
| 140 |
+
# )
|
| 141 |
+
# b_size = encoder_hidden_states1.shape[0]
|
| 142 |
+
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
| 143 |
+
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
| 144 |
+
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
| 145 |
+
# logger.info("text encoder outputs verified")
|
| 146 |
+
|
| 147 |
+
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
| 148 |
+
|
| 149 |
+
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
| 150 |
+
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
| 151 |
+
|
| 152 |
+
# get size embeddings
|
| 153 |
+
orig_size = batch["original_sizes_hw"]
|
| 154 |
+
crop_size = batch["crop_top_lefts"]
|
| 155 |
+
target_size = batch["target_sizes_hw"]
|
| 156 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
| 157 |
+
|
| 158 |
+
# concat embeddings
|
| 159 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
| 160 |
+
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
| 161 |
+
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
| 162 |
+
|
| 163 |
+
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
| 164 |
+
return noise_pred
|
| 165 |
+
|
| 166 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
| 167 |
+
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 171 |
+
parser = train_network.setup_parser()
|
| 172 |
+
sdxl_train_util.add_sdxl_training_arguments(parser)
|
| 173 |
+
return parser
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
parser = setup_parser()
|
| 178 |
+
|
| 179 |
+
args = parser.parse_args()
|
| 180 |
+
train_util.verify_command_line_training_args(args)
|
| 181 |
+
args = train_util.read_config_from_file(args, parser)
|
| 182 |
+
|
| 183 |
+
trainer = SdxlNetworkTrainer()
|
| 184 |
+
trainer.train(args)
|
sdxl_train_textual_inversion.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import regex
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from library.device_utils import init_ipex
|
| 8 |
+
init_ipex()
|
| 9 |
+
|
| 10 |
+
from library import sdxl_model_util, sdxl_train_util, train_util
|
| 11 |
+
|
| 12 |
+
import train_textual_inversion
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
|
| 16 |
+
def __init__(self):
|
| 17 |
+
super().__init__()
|
| 18 |
+
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
| 19 |
+
self.is_sdxl = True
|
| 20 |
+
|
| 21 |
+
def assert_extra_args(self, args, train_dataset_group):
|
| 22 |
+
super().assert_extra_args(args, train_dataset_group)
|
| 23 |
+
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
|
| 24 |
+
|
| 25 |
+
train_dataset_group.verify_bucket_reso_steps(32)
|
| 26 |
+
|
| 27 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
| 28 |
+
(
|
| 29 |
+
load_stable_diffusion_format,
|
| 30 |
+
text_encoder1,
|
| 31 |
+
text_encoder2,
|
| 32 |
+
vae,
|
| 33 |
+
unet,
|
| 34 |
+
logit_scale,
|
| 35 |
+
ckpt_info,
|
| 36 |
+
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
| 37 |
+
|
| 38 |
+
self.load_stable_diffusion_format = load_stable_diffusion_format
|
| 39 |
+
self.logit_scale = logit_scale
|
| 40 |
+
self.ckpt_info = ckpt_info
|
| 41 |
+
|
| 42 |
+
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
| 43 |
+
|
| 44 |
+
def load_tokenizer(self, args):
|
| 45 |
+
tokenizer = sdxl_train_util.load_tokenizers(args)
|
| 46 |
+
return tokenizer
|
| 47 |
+
|
| 48 |
+
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
| 49 |
+
input_ids1 = batch["input_ids"]
|
| 50 |
+
input_ids2 = batch["input_ids2"]
|
| 51 |
+
with torch.enable_grad():
|
| 52 |
+
input_ids1 = input_ids1.to(accelerator.device)
|
| 53 |
+
input_ids2 = input_ids2.to(accelerator.device)
|
| 54 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
| 55 |
+
args.max_token_length,
|
| 56 |
+
input_ids1,
|
| 57 |
+
input_ids2,
|
| 58 |
+
tokenizers[0],
|
| 59 |
+
tokenizers[1],
|
| 60 |
+
text_encoders[0],
|
| 61 |
+
text_encoders[1],
|
| 62 |
+
None if not args.full_fp16 else weight_dtype,
|
| 63 |
+
accelerator=accelerator,
|
| 64 |
+
)
|
| 65 |
+
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
| 66 |
+
|
| 67 |
+
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
| 68 |
+
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
|
| 69 |
+
|
| 70 |
+
# get size embeddings
|
| 71 |
+
orig_size = batch["original_sizes_hw"]
|
| 72 |
+
crop_size = batch["crop_top_lefts"]
|
| 73 |
+
target_size = batch["target_sizes_hw"]
|
| 74 |
+
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
|
| 75 |
+
|
| 76 |
+
# concat embeddings
|
| 77 |
+
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
|
| 78 |
+
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
|
| 79 |
+
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
|
| 80 |
+
|
| 81 |
+
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
|
| 82 |
+
return noise_pred
|
| 83 |
+
|
| 84 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
| 85 |
+
sdxl_train_util.sample_images(
|
| 86 |
+
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
| 90 |
+
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
|
| 91 |
+
|
| 92 |
+
if save_dtype is not None:
|
| 93 |
+
for key in list(state_dict.keys()):
|
| 94 |
+
v = state_dict[key]
|
| 95 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
| 96 |
+
state_dict[key] = v
|
| 97 |
+
|
| 98 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 99 |
+
from safetensors.torch import save_file
|
| 100 |
+
|
| 101 |
+
save_file(state_dict, file, metadata)
|
| 102 |
+
else:
|
| 103 |
+
torch.save(state_dict, file)
|
| 104 |
+
|
| 105 |
+
def load_weights(self, file):
|
| 106 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 107 |
+
from safetensors.torch import load_file
|
| 108 |
+
|
| 109 |
+
data = load_file(file)
|
| 110 |
+
else:
|
| 111 |
+
data = torch.load(file, map_location="cpu")
|
| 112 |
+
|
| 113 |
+
emb_l = data.get("clip_l", None) # ViT-L text encoder 1
|
| 114 |
+
emb_g = data.get("clip_g", None) # BiG-G text encoder 2
|
| 115 |
+
|
| 116 |
+
assert (
|
| 117 |
+
emb_l is not None or emb_g is not None
|
| 118 |
+
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
|
| 119 |
+
|
| 120 |
+
return [emb_l, emb_g]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 124 |
+
parser = train_textual_inversion.setup_parser()
|
| 125 |
+
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
|
| 126 |
+
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
| 127 |
+
return parser
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
if __name__ == "__main__":
|
| 131 |
+
parser = setup_parser()
|
| 132 |
+
|
| 133 |
+
args = parser.parse_args()
|
| 134 |
+
train_util.verify_command_line_training_args(args)
|
| 135 |
+
args = train_util.read_config_from_file(args, parser)
|
| 136 |
+
|
| 137 |
+
trainer = SdxlTextualInversionTrainer()
|
| 138 |
+
trainer.train(args)
|
setup.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(name = "library", packages = find_packages())
|
train_controlnet.py
ADDED
|
@@ -0,0 +1,648 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import time
|
| 7 |
+
from multiprocessing import Value
|
| 8 |
+
|
| 9 |
+
# from omegaconf import OmegaConf
|
| 10 |
+
import toml
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from library import deepspeed_utils
|
| 16 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 17 |
+
|
| 18 |
+
init_ipex()
|
| 19 |
+
|
| 20 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 21 |
+
from accelerate.utils import set_seed
|
| 22 |
+
from diffusers import DDPMScheduler, ControlNetModel
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
|
| 25 |
+
import library.model_util as model_util
|
| 26 |
+
import library.train_util as train_util
|
| 27 |
+
import library.config_util as config_util
|
| 28 |
+
from library.config_util import (
|
| 29 |
+
ConfigSanitizer,
|
| 30 |
+
BlueprintGenerator,
|
| 31 |
+
)
|
| 32 |
+
import library.huggingface_util as huggingface_util
|
| 33 |
+
import library.custom_train_functions as custom_train_functions
|
| 34 |
+
from library.custom_train_functions import (
|
| 35 |
+
apply_snr_weight,
|
| 36 |
+
pyramid_noise_like,
|
| 37 |
+
apply_noise_offset,
|
| 38 |
+
)
|
| 39 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 40 |
+
|
| 41 |
+
setup_logging()
|
| 42 |
+
import logging
|
| 43 |
+
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# TODO 他のスクリプトと共通化する
|
| 48 |
+
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
| 49 |
+
logs = {
|
| 50 |
+
"loss/current": current_loss,
|
| 51 |
+
"loss/average": avr_loss,
|
| 52 |
+
"lr": lr_scheduler.get_last_lr()[0],
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
if args.optimizer_type.lower().startswith("DAdapt".lower()):
|
| 56 |
+
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
| 57 |
+
|
| 58 |
+
return logs
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def train(args):
|
| 62 |
+
# session_id = random.randint(0, 2**32)
|
| 63 |
+
# training_started_at = time.time()
|
| 64 |
+
train_util.verify_training_args(args)
|
| 65 |
+
train_util.prepare_dataset_args(args, True)
|
| 66 |
+
setup_logging(args, reset=True)
|
| 67 |
+
|
| 68 |
+
cache_latents = args.cache_latents
|
| 69 |
+
use_user_config = args.dataset_config is not None
|
| 70 |
+
|
| 71 |
+
if args.seed is None:
|
| 72 |
+
args.seed = random.randint(0, 2**32)
|
| 73 |
+
set_seed(args.seed)
|
| 74 |
+
|
| 75 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 76 |
+
|
| 77 |
+
# データセットを準備する
|
| 78 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
| 79 |
+
if use_user_config:
|
| 80 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 81 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 82 |
+
ignored = ["train_data_dir", "conditioning_data_dir"]
|
| 83 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 84 |
+
logger.warning(
|
| 85 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 86 |
+
", ".join(ignored)
|
| 87 |
+
)
|
| 88 |
+
)
|
| 89 |
+
else:
|
| 90 |
+
user_config = {
|
| 91 |
+
"datasets": [
|
| 92 |
+
{
|
| 93 |
+
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
|
| 94 |
+
args.train_data_dir,
|
| 95 |
+
args.conditioning_data_dir,
|
| 96 |
+
args.caption_extension,
|
| 97 |
+
)
|
| 98 |
+
}
|
| 99 |
+
]
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
| 103 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 104 |
+
|
| 105 |
+
current_epoch = Value("i", 0)
|
| 106 |
+
current_step = Value("i", 0)
|
| 107 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 108 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 109 |
+
|
| 110 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
| 111 |
+
|
| 112 |
+
if args.debug_dataset:
|
| 113 |
+
train_util.debug_dataset(train_dataset_group)
|
| 114 |
+
return
|
| 115 |
+
if len(train_dataset_group) == 0:
|
| 116 |
+
logger.error(
|
| 117 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
| 118 |
+
)
|
| 119 |
+
return
|
| 120 |
+
|
| 121 |
+
if cache_latents:
|
| 122 |
+
assert (
|
| 123 |
+
train_dataset_group.is_latent_cacheable()
|
| 124 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 125 |
+
|
| 126 |
+
# acceleratorを準備する
|
| 127 |
+
logger.info("prepare accelerator")
|
| 128 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 129 |
+
is_main_process = accelerator.is_main_process
|
| 130 |
+
|
| 131 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 132 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 133 |
+
|
| 134 |
+
# モデルを読み込む
|
| 135 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(
|
| 136 |
+
args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# DiffusersのControlNetが使用するデータを準備する
|
| 140 |
+
if args.v2:
|
| 141 |
+
unet.config = {
|
| 142 |
+
"act_fn": "silu",
|
| 143 |
+
"attention_head_dim": [5, 10, 20, 20],
|
| 144 |
+
"block_out_channels": [320, 640, 1280, 1280],
|
| 145 |
+
"center_input_sample": False,
|
| 146 |
+
"cross_attention_dim": 1024,
|
| 147 |
+
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
|
| 148 |
+
"downsample_padding": 1,
|
| 149 |
+
"dual_cross_attention": False,
|
| 150 |
+
"flip_sin_to_cos": True,
|
| 151 |
+
"freq_shift": 0,
|
| 152 |
+
"in_channels": 4,
|
| 153 |
+
"layers_per_block": 2,
|
| 154 |
+
"mid_block_scale_factor": 1,
|
| 155 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
| 156 |
+
"norm_eps": 1e-05,
|
| 157 |
+
"norm_num_groups": 32,
|
| 158 |
+
"num_attention_heads": [5, 10, 20, 20],
|
| 159 |
+
"num_class_embeds": None,
|
| 160 |
+
"only_cross_attention": False,
|
| 161 |
+
"out_channels": 4,
|
| 162 |
+
"sample_size": 96,
|
| 163 |
+
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
| 164 |
+
"use_linear_projection": True,
|
| 165 |
+
"upcast_attention": True,
|
| 166 |
+
"only_cross_attention": False,
|
| 167 |
+
"downsample_padding": 1,
|
| 168 |
+
"use_linear_projection": True,
|
| 169 |
+
"class_embed_type": None,
|
| 170 |
+
"num_class_embeds": None,
|
| 171 |
+
"resnet_time_scale_shift": "default",
|
| 172 |
+
"projection_class_embeddings_input_dim": None,
|
| 173 |
+
}
|
| 174 |
+
else:
|
| 175 |
+
unet.config = {
|
| 176 |
+
"act_fn": "silu",
|
| 177 |
+
"attention_head_dim": 8,
|
| 178 |
+
"block_out_channels": [320, 640, 1280, 1280],
|
| 179 |
+
"center_input_sample": False,
|
| 180 |
+
"cross_attention_dim": 768,
|
| 181 |
+
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
|
| 182 |
+
"downsample_padding": 1,
|
| 183 |
+
"flip_sin_to_cos": True,
|
| 184 |
+
"freq_shift": 0,
|
| 185 |
+
"in_channels": 4,
|
| 186 |
+
"layers_per_block": 2,
|
| 187 |
+
"mid_block_scale_factor": 1,
|
| 188 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
| 189 |
+
"norm_eps": 1e-05,
|
| 190 |
+
"norm_num_groups": 32,
|
| 191 |
+
"num_attention_heads": 8,
|
| 192 |
+
"out_channels": 4,
|
| 193 |
+
"sample_size": 64,
|
| 194 |
+
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
| 195 |
+
"only_cross_attention": False,
|
| 196 |
+
"downsample_padding": 1,
|
| 197 |
+
"use_linear_projection": False,
|
| 198 |
+
"class_embed_type": None,
|
| 199 |
+
"num_class_embeds": None,
|
| 200 |
+
"upcast_attention": False,
|
| 201 |
+
"resnet_time_scale_shift": "default",
|
| 202 |
+
"projection_class_embeddings_input_dim": None,
|
| 203 |
+
}
|
| 204 |
+
# unet.config = OmegaConf.create(unet.config)
|
| 205 |
+
|
| 206 |
+
# make unet.config iterable and accessible by attribute
|
| 207 |
+
class CustomConfig:
|
| 208 |
+
def __init__(self, **kwargs):
|
| 209 |
+
self.__dict__.update(kwargs)
|
| 210 |
+
|
| 211 |
+
def __getattr__(self, name):
|
| 212 |
+
if name in self.__dict__:
|
| 213 |
+
return self.__dict__[name]
|
| 214 |
+
else:
|
| 215 |
+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
| 216 |
+
|
| 217 |
+
def __contains__(self, name):
|
| 218 |
+
return name in self.__dict__
|
| 219 |
+
|
| 220 |
+
unet.config = CustomConfig(**unet.config)
|
| 221 |
+
|
| 222 |
+
controlnet = ControlNetModel.from_unet(unet)
|
| 223 |
+
|
| 224 |
+
if args.controlnet_model_name_or_path:
|
| 225 |
+
filename = args.controlnet_model_name_or_path
|
| 226 |
+
if os.path.isfile(filename):
|
| 227 |
+
if os.path.splitext(filename)[1] == ".safetensors":
|
| 228 |
+
state_dict = load_file(filename)
|
| 229 |
+
else:
|
| 230 |
+
state_dict = torch.load(filename)
|
| 231 |
+
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
|
| 232 |
+
controlnet.load_state_dict(state_dict)
|
| 233 |
+
elif os.path.isdir(filename):
|
| 234 |
+
controlnet = ControlNetModel.from_pretrained(filename)
|
| 235 |
+
|
| 236 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 237 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 238 |
+
|
| 239 |
+
# 学習を準備する
|
| 240 |
+
if cache_latents:
|
| 241 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 242 |
+
vae.requires_grad_(False)
|
| 243 |
+
vae.eval()
|
| 244 |
+
with torch.no_grad():
|
| 245 |
+
train_dataset_group.cache_latents(
|
| 246 |
+
vae,
|
| 247 |
+
args.vae_batch_size,
|
| 248 |
+
args.cache_latents_to_disk,
|
| 249 |
+
accelerator.is_main_process,
|
| 250 |
+
)
|
| 251 |
+
vae.to("cpu")
|
| 252 |
+
clean_memory_on_device(accelerator.device)
|
| 253 |
+
|
| 254 |
+
accelerator.wait_for_everyone()
|
| 255 |
+
|
| 256 |
+
if args.gradient_checkpointing:
|
| 257 |
+
controlnet.enable_gradient_checkpointing()
|
| 258 |
+
|
| 259 |
+
# 学習に必要なクラスを準備する
|
| 260 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 261 |
+
|
| 262 |
+
trainable_params = list(controlnet.parameters())
|
| 263 |
+
|
| 264 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 265 |
+
|
| 266 |
+
# dataloaderを準備する
|
| 267 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 268 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 269 |
+
|
| 270 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 271 |
+
train_dataset_group,
|
| 272 |
+
batch_size=1,
|
| 273 |
+
shuffle=True,
|
| 274 |
+
collate_fn=collator,
|
| 275 |
+
num_workers=n_workers,
|
| 276 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# 学習ステップ数を計算する
|
| 280 |
+
if args.max_train_epochs is not None:
|
| 281 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 282 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 283 |
+
)
|
| 284 |
+
accelerator.print(
|
| 285 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# データセット側にも学習ステップを送信
|
| 289 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 290 |
+
|
| 291 |
+
# lr schedulerを用意する
|
| 292 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 293 |
+
|
| 294 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
| 295 |
+
if args.full_fp16:
|
| 296 |
+
assert (
|
| 297 |
+
args.mixed_precision == "fp16"
|
| 298 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 299 |
+
accelerator.print("enable full fp16 training.")
|
| 300 |
+
controlnet.to(weight_dtype)
|
| 301 |
+
|
| 302 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 303 |
+
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 304 |
+
controlnet, optimizer, train_dataloader, lr_scheduler
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
unet.requires_grad_(False)
|
| 308 |
+
text_encoder.requires_grad_(False)
|
| 309 |
+
unet.to(accelerator.device)
|
| 310 |
+
text_encoder.to(accelerator.device)
|
| 311 |
+
|
| 312 |
+
# transform DDP after prepare
|
| 313 |
+
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
|
| 314 |
+
|
| 315 |
+
controlnet.train()
|
| 316 |
+
|
| 317 |
+
if not cache_latents:
|
| 318 |
+
vae.requires_grad_(False)
|
| 319 |
+
vae.eval()
|
| 320 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 321 |
+
|
| 322 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 323 |
+
if args.full_fp16:
|
| 324 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 325 |
+
|
| 326 |
+
# resumeする
|
| 327 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 328 |
+
|
| 329 |
+
# epoch数を計算する
|
| 330 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 331 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 332 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 333 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 334 |
+
|
| 335 |
+
# 学習する
|
| 336 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
| 337 |
+
accelerator.print("running training / 学習開始")
|
| 338 |
+
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 339 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 340 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 341 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 342 |
+
accelerator.print(
|
| 343 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
| 344 |
+
)
|
| 345 |
+
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
| 346 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 347 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 348 |
+
|
| 349 |
+
progress_bar = tqdm(
|
| 350 |
+
range(args.max_train_steps),
|
| 351 |
+
smoothing=0,
|
| 352 |
+
disable=not accelerator.is_local_main_process,
|
| 353 |
+
desc="steps",
|
| 354 |
+
)
|
| 355 |
+
global_step = 0
|
| 356 |
+
|
| 357 |
+
noise_scheduler = DDPMScheduler(
|
| 358 |
+
beta_start=0.00085,
|
| 359 |
+
beta_end=0.012,
|
| 360 |
+
beta_schedule="scaled_linear",
|
| 361 |
+
num_train_timesteps=1000,
|
| 362 |
+
clip_sample=False,
|
| 363 |
+
)
|
| 364 |
+
if accelerator.is_main_process:
|
| 365 |
+
init_kwargs = {}
|
| 366 |
+
if args.wandb_run_name:
|
| 367 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 368 |
+
if args.log_tracker_config is not None:
|
| 369 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 370 |
+
accelerator.init_trackers(
|
| 371 |
+
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
|
| 372 |
+
config=train_util.get_sanitized_config_or_none(args),
|
| 373 |
+
init_kwargs=init_kwargs,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
loss_recorder = train_util.LossRecorder()
|
| 377 |
+
del train_dataset_group
|
| 378 |
+
|
| 379 |
+
# function for saving/removing
|
| 380 |
+
def save_model(ckpt_name, model, force_sync_upload=False):
|
| 381 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 382 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
| 383 |
+
|
| 384 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
| 385 |
+
|
| 386 |
+
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
|
| 387 |
+
|
| 388 |
+
if save_dtype is not None:
|
| 389 |
+
for key in list(state_dict.keys()):
|
| 390 |
+
v = state_dict[key]
|
| 391 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
| 392 |
+
state_dict[key] = v
|
| 393 |
+
|
| 394 |
+
if os.path.splitext(ckpt_file)[1] == ".safetensors":
|
| 395 |
+
from safetensors.torch import save_file
|
| 396 |
+
|
| 397 |
+
save_file(state_dict, ckpt_file)
|
| 398 |
+
else:
|
| 399 |
+
torch.save(state_dict, ckpt_file)
|
| 400 |
+
|
| 401 |
+
if args.huggingface_repo_id is not None:
|
| 402 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
| 403 |
+
|
| 404 |
+
def remove_model(old_ckpt_name):
|
| 405 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
| 406 |
+
if os.path.exists(old_ckpt_file):
|
| 407 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
| 408 |
+
os.remove(old_ckpt_file)
|
| 409 |
+
|
| 410 |
+
# For --sample_at_first
|
| 411 |
+
train_util.sample_images(
|
| 412 |
+
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# training loop
|
| 416 |
+
for epoch in range(num_train_epochs):
|
| 417 |
+
if is_main_process:
|
| 418 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 419 |
+
current_epoch.value = epoch + 1
|
| 420 |
+
|
| 421 |
+
for step, batch in enumerate(train_dataloader):
|
| 422 |
+
current_step.value = global_step
|
| 423 |
+
with accelerator.accumulate(controlnet):
|
| 424 |
+
with torch.no_grad():
|
| 425 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 426 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 427 |
+
else:
|
| 428 |
+
# latentに変換
|
| 429 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
| 430 |
+
latents = latents * 0.18215
|
| 431 |
+
b_size = latents.shape[0]
|
| 432 |
+
|
| 433 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
| 434 |
+
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
| 435 |
+
|
| 436 |
+
# Sample noise that we'll add to the latents
|
| 437 |
+
noise = torch.randn_like(latents, device=latents.device)
|
| 438 |
+
if args.noise_offset:
|
| 439 |
+
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
| 440 |
+
elif args.multires_noise_iterations:
|
| 441 |
+
noise = pyramid_noise_like(
|
| 442 |
+
noise,
|
| 443 |
+
latents.device,
|
| 444 |
+
args.multires_noise_iterations,
|
| 445 |
+
args.multires_noise_discount,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# Sample a random timestep for each image
|
| 449 |
+
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
|
| 450 |
+
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
|
| 451 |
+
)
|
| 452 |
+
|
| 453 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 454 |
+
# (this is the forward diffusion process)
|
| 455 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 456 |
+
|
| 457 |
+
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
|
| 458 |
+
|
| 459 |
+
with accelerator.autocast():
|
| 460 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
| 461 |
+
noisy_latents,
|
| 462 |
+
timesteps,
|
| 463 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 464 |
+
controlnet_cond=controlnet_image,
|
| 465 |
+
return_dict=False,
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
# Predict the noise residual
|
| 469 |
+
noise_pred = unet(
|
| 470 |
+
noisy_latents,
|
| 471 |
+
timesteps,
|
| 472 |
+
encoder_hidden_states,
|
| 473 |
+
down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
|
| 474 |
+
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
| 475 |
+
).sample
|
| 476 |
+
|
| 477 |
+
if args.v_parameterization:
|
| 478 |
+
# v-parameterization training
|
| 479 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 480 |
+
else:
|
| 481 |
+
target = noise
|
| 482 |
+
|
| 483 |
+
loss = train_util.conditional_loss(
|
| 484 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
| 485 |
+
)
|
| 486 |
+
loss = loss.mean([1, 2, 3])
|
| 487 |
+
|
| 488 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 489 |
+
loss = loss * loss_weights
|
| 490 |
+
|
| 491 |
+
if args.min_snr_gamma:
|
| 492 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 493 |
+
|
| 494 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 495 |
+
|
| 496 |
+
accelerator.backward(loss)
|
| 497 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 498 |
+
params_to_clip = controlnet.parameters()
|
| 499 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 500 |
+
|
| 501 |
+
optimizer.step()
|
| 502 |
+
lr_scheduler.step()
|
| 503 |
+
optimizer.zero_grad(set_to_none=True)
|
| 504 |
+
|
| 505 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 506 |
+
if accelerator.sync_gradients:
|
| 507 |
+
progress_bar.update(1)
|
| 508 |
+
global_step += 1
|
| 509 |
+
|
| 510 |
+
train_util.sample_images(
|
| 511 |
+
accelerator,
|
| 512 |
+
args,
|
| 513 |
+
None,
|
| 514 |
+
global_step,
|
| 515 |
+
accelerator.device,
|
| 516 |
+
vae,
|
| 517 |
+
tokenizer,
|
| 518 |
+
text_encoder,
|
| 519 |
+
unet,
|
| 520 |
+
controlnet=controlnet,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# 指定ステップごとにモデルを保存
|
| 524 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 525 |
+
accelerator.wait_for_everyone()
|
| 526 |
+
if accelerator.is_main_process:
|
| 527 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
| 528 |
+
save_model(
|
| 529 |
+
ckpt_name,
|
| 530 |
+
accelerator.unwrap_model(controlnet),
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
if args.save_state:
|
| 534 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
| 535 |
+
|
| 536 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
| 537 |
+
if remove_step_no is not None:
|
| 538 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
| 539 |
+
remove_model(remove_ckpt_name)
|
| 540 |
+
|
| 541 |
+
current_loss = loss.detach().item()
|
| 542 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 543 |
+
avr_loss: float = loss_recorder.moving_average
|
| 544 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 545 |
+
progress_bar.set_postfix(**logs)
|
| 546 |
+
|
| 547 |
+
if args.logging_dir is not None:
|
| 548 |
+
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
| 549 |
+
accelerator.log(logs, step=global_step)
|
| 550 |
+
|
| 551 |
+
if global_step >= args.max_train_steps:
|
| 552 |
+
break
|
| 553 |
+
|
| 554 |
+
if args.logging_dir is not None:
|
| 555 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 556 |
+
accelerator.log(logs, step=epoch + 1)
|
| 557 |
+
|
| 558 |
+
accelerator.wait_for_everyone()
|
| 559 |
+
|
| 560 |
+
# 指定エポックごとにモデルを保存
|
| 561 |
+
if args.save_every_n_epochs is not None:
|
| 562 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
| 563 |
+
if is_main_process and saving:
|
| 564 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
| 565 |
+
save_model(ckpt_name, accelerator.unwrap_model(controlnet))
|
| 566 |
+
|
| 567 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
| 568 |
+
if remove_epoch_no is not None:
|
| 569 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
| 570 |
+
remove_model(remove_ckpt_name)
|
| 571 |
+
|
| 572 |
+
if args.save_state:
|
| 573 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
| 574 |
+
|
| 575 |
+
train_util.sample_images(
|
| 576 |
+
accelerator,
|
| 577 |
+
args,
|
| 578 |
+
epoch + 1,
|
| 579 |
+
global_step,
|
| 580 |
+
accelerator.device,
|
| 581 |
+
vae,
|
| 582 |
+
tokenizer,
|
| 583 |
+
text_encoder,
|
| 584 |
+
unet,
|
| 585 |
+
controlnet=controlnet,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
# end of epoch
|
| 589 |
+
if is_main_process:
|
| 590 |
+
controlnet = accelerator.unwrap_model(controlnet)
|
| 591 |
+
|
| 592 |
+
accelerator.end_training()
|
| 593 |
+
|
| 594 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 595 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 596 |
+
|
| 597 |
+
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
|
| 598 |
+
|
| 599 |
+
if is_main_process:
|
| 600 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
| 601 |
+
save_model(ckpt_name, controlnet, force_sync_upload=True)
|
| 602 |
+
|
| 603 |
+
logger.info("model saved.")
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 607 |
+
parser = argparse.ArgumentParser()
|
| 608 |
+
|
| 609 |
+
add_logging_arguments(parser)
|
| 610 |
+
train_util.add_sd_models_arguments(parser)
|
| 611 |
+
train_util.add_dataset_arguments(parser, False, True, True)
|
| 612 |
+
train_util.add_training_arguments(parser, False)
|
| 613 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 614 |
+
train_util.add_optimizer_arguments(parser)
|
| 615 |
+
config_util.add_config_arguments(parser)
|
| 616 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 617 |
+
|
| 618 |
+
parser.add_argument(
|
| 619 |
+
"--save_model_as",
|
| 620 |
+
type=str,
|
| 621 |
+
default="safetensors",
|
| 622 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
| 623 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
| 624 |
+
)
|
| 625 |
+
parser.add_argument(
|
| 626 |
+
"--controlnet_model_name_or_path",
|
| 627 |
+
type=str,
|
| 628 |
+
default=None,
|
| 629 |
+
help="controlnet model name or path / controlnetのモデル名またはパス",
|
| 630 |
+
)
|
| 631 |
+
parser.add_argument(
|
| 632 |
+
"--conditioning_data_dir",
|
| 633 |
+
type=str,
|
| 634 |
+
default=None,
|
| 635 |
+
help="conditioning data directory / 条件付けデータのディレクトリ",
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
return parser
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
if __name__ == "__main__":
|
| 642 |
+
parser = setup_parser()
|
| 643 |
+
|
| 644 |
+
args = parser.parse_args()
|
| 645 |
+
train_util.verify_command_line_training_args(args)
|
| 646 |
+
args = train_util.read_config_from_file(args, parser)
|
| 647 |
+
|
| 648 |
+
train(args)
|
train_db.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DreamBooth training
|
| 2 |
+
# XXX dropped option: fine_tune
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import itertools
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
from multiprocessing import Value
|
| 9 |
+
import toml
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from library import deepspeed_utils
|
| 15 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
init_ipex()
|
| 19 |
+
|
| 20 |
+
from accelerate.utils import set_seed
|
| 21 |
+
from diffusers import DDPMScheduler
|
| 22 |
+
|
| 23 |
+
import library.train_util as train_util
|
| 24 |
+
import library.config_util as config_util
|
| 25 |
+
from library.config_util import (
|
| 26 |
+
ConfigSanitizer,
|
| 27 |
+
BlueprintGenerator,
|
| 28 |
+
)
|
| 29 |
+
import library.custom_train_functions as custom_train_functions
|
| 30 |
+
from library.custom_train_functions import (
|
| 31 |
+
apply_snr_weight,
|
| 32 |
+
get_weighted_text_embeddings,
|
| 33 |
+
prepare_scheduler_for_custom_training,
|
| 34 |
+
pyramid_noise_like,
|
| 35 |
+
apply_noise_offset,
|
| 36 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 37 |
+
apply_debiased_estimation,
|
| 38 |
+
apply_masked_loss,
|
| 39 |
+
)
|
| 40 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 41 |
+
|
| 42 |
+
setup_logging()
|
| 43 |
+
import logging
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
# perlin_noise,
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def train(args):
|
| 51 |
+
train_util.verify_training_args(args)
|
| 52 |
+
train_util.prepare_dataset_args(args, False)
|
| 53 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
| 54 |
+
setup_logging(args, reset=True)
|
| 55 |
+
|
| 56 |
+
cache_latents = args.cache_latents
|
| 57 |
+
|
| 58 |
+
if args.seed is not None:
|
| 59 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
| 60 |
+
|
| 61 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 62 |
+
|
| 63 |
+
# データセットを準備する
|
| 64 |
+
if args.dataset_class is None:
|
| 65 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
|
| 66 |
+
if args.dataset_config is not None:
|
| 67 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 68 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 69 |
+
ignored = ["train_data_dir", "reg_data_dir"]
|
| 70 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 71 |
+
logger.warning(
|
| 72 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 73 |
+
", ".join(ignored)
|
| 74 |
+
)
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
user_config = {
|
| 78 |
+
"datasets": [
|
| 79 |
+
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
| 80 |
+
]
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
| 84 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 85 |
+
else:
|
| 86 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
| 87 |
+
|
| 88 |
+
current_epoch = Value("i", 0)
|
| 89 |
+
current_step = Value("i", 0)
|
| 90 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 91 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 92 |
+
|
| 93 |
+
if args.no_token_padding:
|
| 94 |
+
train_dataset_group.disable_token_padding()
|
| 95 |
+
|
| 96 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
| 97 |
+
|
| 98 |
+
if args.debug_dataset:
|
| 99 |
+
train_util.debug_dataset(train_dataset_group)
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
if cache_latents:
|
| 103 |
+
assert (
|
| 104 |
+
train_dataset_group.is_latent_cacheable()
|
| 105 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 106 |
+
|
| 107 |
+
# acceleratorを準備する
|
| 108 |
+
logger.info("prepare accelerator")
|
| 109 |
+
|
| 110 |
+
if args.gradient_accumulation_steps > 1:
|
| 111 |
+
logger.warning(
|
| 112 |
+
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
| 113 |
+
)
|
| 114 |
+
logger.warning(
|
| 115 |
+
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 119 |
+
|
| 120 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 121 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 122 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 123 |
+
|
| 124 |
+
# モデルを読み込む
|
| 125 |
+
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 126 |
+
|
| 127 |
+
# verify load/save model formats
|
| 128 |
+
if load_stable_diffusion_format:
|
| 129 |
+
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
|
| 130 |
+
src_diffusers_model_path = None
|
| 131 |
+
else:
|
| 132 |
+
src_stable_diffusion_ckpt = None
|
| 133 |
+
src_diffusers_model_path = args.pretrained_model_name_or_path
|
| 134 |
+
|
| 135 |
+
if args.save_model_as is None:
|
| 136 |
+
save_stable_diffusion_format = load_stable_diffusion_format
|
| 137 |
+
use_safetensors = args.use_safetensors
|
| 138 |
+
else:
|
| 139 |
+
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
| 140 |
+
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
| 141 |
+
|
| 142 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 143 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 144 |
+
|
| 145 |
+
# 学習を準備する
|
| 146 |
+
if cache_latents:
|
| 147 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 148 |
+
vae.requires_grad_(False)
|
| 149 |
+
vae.eval()
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
| 152 |
+
vae.to("cpu")
|
| 153 |
+
clean_memory_on_device(accelerator.device)
|
| 154 |
+
|
| 155 |
+
accelerator.wait_for_everyone()
|
| 156 |
+
|
| 157 |
+
# 学習を準備する:モデルを適切な状態にする
|
| 158 |
+
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
| 159 |
+
unet.requires_grad_(True) # 念のため追加
|
| 160 |
+
text_encoder.requires_grad_(train_text_encoder)
|
| 161 |
+
if not train_text_encoder:
|
| 162 |
+
accelerator.print("Text Encoder is not trained.")
|
| 163 |
+
|
| 164 |
+
if args.gradient_checkpointing:
|
| 165 |
+
unet.enable_gradient_checkpointing()
|
| 166 |
+
text_encoder.gradient_checkpointing_enable()
|
| 167 |
+
|
| 168 |
+
if not cache_latents:
|
| 169 |
+
vae.requires_grad_(False)
|
| 170 |
+
vae.eval()
|
| 171 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 172 |
+
|
| 173 |
+
# 学習に必要なクラスを準備する
|
| 174 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 175 |
+
if train_text_encoder:
|
| 176 |
+
if args.learning_rate_te is None:
|
| 177 |
+
# wightout list, adamw8bit is crashed
|
| 178 |
+
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
|
| 179 |
+
else:
|
| 180 |
+
trainable_params = [
|
| 181 |
+
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
| 182 |
+
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
| 183 |
+
]
|
| 184 |
+
else:
|
| 185 |
+
trainable_params = unet.parameters()
|
| 186 |
+
|
| 187 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 188 |
+
|
| 189 |
+
# dataloaderを準備する
|
| 190 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 191 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 192 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 193 |
+
train_dataset_group,
|
| 194 |
+
batch_size=1,
|
| 195 |
+
shuffle=True,
|
| 196 |
+
collate_fn=collator,
|
| 197 |
+
num_workers=n_workers,
|
| 198 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# 学習ステップ数を計算する
|
| 202 |
+
if args.max_train_epochs is not None:
|
| 203 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 204 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 205 |
+
)
|
| 206 |
+
accelerator.print(
|
| 207 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
# データセット側にも学習ステップを送信
|
| 211 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 212 |
+
|
| 213 |
+
if args.stop_text_encoder_training is None:
|
| 214 |
+
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
| 215 |
+
|
| 216 |
+
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
| 217 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 218 |
+
|
| 219 |
+
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
| 220 |
+
if args.full_fp16:
|
| 221 |
+
assert (
|
| 222 |
+
args.mixed_precision == "fp16"
|
| 223 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 224 |
+
accelerator.print("enable full fp16 training.")
|
| 225 |
+
unet.to(weight_dtype)
|
| 226 |
+
text_encoder.to(weight_dtype)
|
| 227 |
+
|
| 228 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 229 |
+
if args.deepspeed:
|
| 230 |
+
if args.train_text_encoder:
|
| 231 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
|
| 232 |
+
else:
|
| 233 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
|
| 234 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 235 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
| 236 |
+
)
|
| 237 |
+
training_models = [ds_model]
|
| 238 |
+
|
| 239 |
+
else:
|
| 240 |
+
if train_text_encoder:
|
| 241 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 242 |
+
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 243 |
+
)
|
| 244 |
+
training_models = [unet, text_encoder]
|
| 245 |
+
else:
|
| 246 |
+
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
| 247 |
+
training_models = [unet]
|
| 248 |
+
|
| 249 |
+
if not train_text_encoder:
|
| 250 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
| 251 |
+
|
| 252 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 253 |
+
if args.full_fp16:
|
| 254 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 255 |
+
|
| 256 |
+
# resumeする
|
| 257 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 258 |
+
|
| 259 |
+
# epoch数を計算する
|
| 260 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 261 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 262 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 263 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 264 |
+
|
| 265 |
+
# 学習する
|
| 266 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 267 |
+
accelerator.print("running training / 学習開始")
|
| 268 |
+
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 269 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 270 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 271 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 272 |
+
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
| 273 |
+
accelerator.print(
|
| 274 |
+
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
| 275 |
+
)
|
| 276 |
+
accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 277 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 278 |
+
|
| 279 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 280 |
+
global_step = 0
|
| 281 |
+
|
| 282 |
+
noise_scheduler = DDPMScheduler(
|
| 283 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 284 |
+
)
|
| 285 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 286 |
+
if args.zero_terminal_snr:
|
| 287 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 288 |
+
|
| 289 |
+
if accelerator.is_main_process:
|
| 290 |
+
init_kwargs = {}
|
| 291 |
+
if args.wandb_run_name:
|
| 292 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 293 |
+
if args.log_tracker_config is not None:
|
| 294 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 295 |
+
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
|
| 296 |
+
|
| 297 |
+
# For --sample_at_first
|
| 298 |
+
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 299 |
+
|
| 300 |
+
loss_recorder = train_util.LossRecorder()
|
| 301 |
+
for epoch in range(num_train_epochs):
|
| 302 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 303 |
+
current_epoch.value = epoch + 1
|
| 304 |
+
|
| 305 |
+
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
| 306 |
+
unet.train()
|
| 307 |
+
# train==True is required to enable gradient_checkpointing
|
| 308 |
+
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
|
| 309 |
+
text_encoder.train()
|
| 310 |
+
|
| 311 |
+
for step, batch in enumerate(train_dataloader):
|
| 312 |
+
current_step.value = global_step
|
| 313 |
+
# 指定したステップ数でText Encoderの学習を止める
|
| 314 |
+
if global_step == args.stop_text_encoder_training:
|
| 315 |
+
accelerator.print(f"stop text encoder training at step {global_step}")
|
| 316 |
+
if not args.gradient_checkpointing:
|
| 317 |
+
text_encoder.train(False)
|
| 318 |
+
text_encoder.requires_grad_(False)
|
| 319 |
+
if len(training_models) == 2:
|
| 320 |
+
training_models = training_models[0] # remove text_encoder from training_models
|
| 321 |
+
|
| 322 |
+
with accelerator.accumulate(*training_models):
|
| 323 |
+
with torch.no_grad():
|
| 324 |
+
# latentに変換
|
| 325 |
+
if cache_latents:
|
| 326 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 327 |
+
else:
|
| 328 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
| 329 |
+
latents = latents * 0.18215
|
| 330 |
+
b_size = latents.shape[0]
|
| 331 |
+
|
| 332 |
+
# Get the text embedding for conditioning
|
| 333 |
+
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
| 334 |
+
if args.weighted_captions:
|
| 335 |
+
encoder_hidden_states = get_weighted_text_embeddings(
|
| 336 |
+
tokenizer,
|
| 337 |
+
text_encoder,
|
| 338 |
+
batch["captions"],
|
| 339 |
+
accelerator.device,
|
| 340 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
| 341 |
+
clip_skip=args.clip_skip,
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
| 345 |
+
encoder_hidden_states = train_util.get_hidden_states(
|
| 346 |
+
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 350 |
+
# with noise offset and/or multires noise if specified
|
| 351 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
| 352 |
+
|
| 353 |
+
# Predict the noise residual
|
| 354 |
+
with accelerator.autocast():
|
| 355 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 356 |
+
|
| 357 |
+
if args.v_parameterization:
|
| 358 |
+
# v-parameterization training
|
| 359 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 360 |
+
else:
|
| 361 |
+
target = noise
|
| 362 |
+
|
| 363 |
+
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
| 364 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
| 365 |
+
loss = apply_masked_loss(loss, batch)
|
| 366 |
+
loss = loss.mean([1, 2, 3])
|
| 367 |
+
|
| 368 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 369 |
+
loss = loss * loss_weights
|
| 370 |
+
|
| 371 |
+
if args.min_snr_gamma:
|
| 372 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 373 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 374 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 375 |
+
if args.debiased_estimation_loss:
|
| 376 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 377 |
+
|
| 378 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 379 |
+
|
| 380 |
+
accelerator.backward(loss)
|
| 381 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 382 |
+
if train_text_encoder:
|
| 383 |
+
params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
|
| 384 |
+
else:
|
| 385 |
+
params_to_clip = unet.parameters()
|
| 386 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 387 |
+
|
| 388 |
+
optimizer.step()
|
| 389 |
+
lr_scheduler.step()
|
| 390 |
+
optimizer.zero_grad(set_to_none=True)
|
| 391 |
+
|
| 392 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 393 |
+
if accelerator.sync_gradients:
|
| 394 |
+
progress_bar.update(1)
|
| 395 |
+
global_step += 1
|
| 396 |
+
|
| 397 |
+
train_util.sample_images(
|
| 398 |
+
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
# 指定ステップごとにモデルを保存
|
| 402 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 403 |
+
accelerator.wait_for_everyone()
|
| 404 |
+
if accelerator.is_main_process:
|
| 405 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 406 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
| 407 |
+
args,
|
| 408 |
+
False,
|
| 409 |
+
accelerator,
|
| 410 |
+
src_path,
|
| 411 |
+
save_stable_diffusion_format,
|
| 412 |
+
use_safetensors,
|
| 413 |
+
save_dtype,
|
| 414 |
+
epoch,
|
| 415 |
+
num_train_epochs,
|
| 416 |
+
global_step,
|
| 417 |
+
accelerator.unwrap_model(text_encoder),
|
| 418 |
+
accelerator.unwrap_model(unet),
|
| 419 |
+
vae,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
current_loss = loss.detach().item()
|
| 423 |
+
if args.logging_dir is not None:
|
| 424 |
+
logs = {"loss": current_loss}
|
| 425 |
+
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
| 426 |
+
accelerator.log(logs, step=global_step)
|
| 427 |
+
|
| 428 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 429 |
+
avr_loss: float = loss_recorder.moving_average
|
| 430 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 431 |
+
progress_bar.set_postfix(**logs)
|
| 432 |
+
|
| 433 |
+
if global_step >= args.max_train_steps:
|
| 434 |
+
break
|
| 435 |
+
|
| 436 |
+
if args.logging_dir is not None:
|
| 437 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 438 |
+
accelerator.log(logs, step=epoch + 1)
|
| 439 |
+
|
| 440 |
+
accelerator.wait_for_everyone()
|
| 441 |
+
|
| 442 |
+
if args.save_every_n_epochs is not None:
|
| 443 |
+
if accelerator.is_main_process:
|
| 444 |
+
# checking for saving is in util
|
| 445 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 446 |
+
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
| 447 |
+
args,
|
| 448 |
+
True,
|
| 449 |
+
accelerator,
|
| 450 |
+
src_path,
|
| 451 |
+
save_stable_diffusion_format,
|
| 452 |
+
use_safetensors,
|
| 453 |
+
save_dtype,
|
| 454 |
+
epoch,
|
| 455 |
+
num_train_epochs,
|
| 456 |
+
global_step,
|
| 457 |
+
accelerator.unwrap_model(text_encoder),
|
| 458 |
+
accelerator.unwrap_model(unet),
|
| 459 |
+
vae,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 463 |
+
|
| 464 |
+
is_main_process = accelerator.is_main_process
|
| 465 |
+
if is_main_process:
|
| 466 |
+
unet = accelerator.unwrap_model(unet)
|
| 467 |
+
text_encoder = accelerator.unwrap_model(text_encoder)
|
| 468 |
+
|
| 469 |
+
accelerator.end_training()
|
| 470 |
+
|
| 471 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 472 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 473 |
+
|
| 474 |
+
del accelerator # この後メモリを使うのでこれは消す
|
| 475 |
+
|
| 476 |
+
if is_main_process:
|
| 477 |
+
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
| 478 |
+
train_util.save_sd_model_on_train_end(
|
| 479 |
+
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
| 480 |
+
)
|
| 481 |
+
logger.info("model saved.")
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 485 |
+
parser = argparse.ArgumentParser()
|
| 486 |
+
|
| 487 |
+
add_logging_arguments(parser)
|
| 488 |
+
train_util.add_sd_models_arguments(parser)
|
| 489 |
+
train_util.add_dataset_arguments(parser, True, False, True)
|
| 490 |
+
train_util.add_training_arguments(parser, True)
|
| 491 |
+
train_util.add_masked_loss_arguments(parser)
|
| 492 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 493 |
+
train_util.add_sd_saving_arguments(parser)
|
| 494 |
+
train_util.add_optimizer_arguments(parser)
|
| 495 |
+
config_util.add_config_arguments(parser)
|
| 496 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 497 |
+
|
| 498 |
+
parser.add_argument(
|
| 499 |
+
"--learning_rate_te",
|
| 500 |
+
type=float,
|
| 501 |
+
default=None,
|
| 502 |
+
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
| 503 |
+
)
|
| 504 |
+
parser.add_argument(
|
| 505 |
+
"--no_token_padding",
|
| 506 |
+
action="store_true",
|
| 507 |
+
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
|
| 508 |
+
)
|
| 509 |
+
parser.add_argument(
|
| 510 |
+
"--stop_text_encoder_training",
|
| 511 |
+
type=int,
|
| 512 |
+
default=None,
|
| 513 |
+
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
| 514 |
+
)
|
| 515 |
+
parser.add_argument(
|
| 516 |
+
"--no_half_vae",
|
| 517 |
+
action="store_true",
|
| 518 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
return parser
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
if __name__ == "__main__":
|
| 525 |
+
parser = setup_parser()
|
| 526 |
+
|
| 527 |
+
args = parser.parse_args()
|
| 528 |
+
train_util.verify_command_line_training_args(args)
|
| 529 |
+
args = train_util.read_config_from_file(args, parser)
|
| 530 |
+
|
| 531 |
+
train(args)
|
train_network.py
ADDED
|
@@ -0,0 +1,1250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
import json
|
| 9 |
+
from multiprocessing import Value
|
| 10 |
+
import toml
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 16 |
+
|
| 17 |
+
init_ipex()
|
| 18 |
+
|
| 19 |
+
from accelerate.utils import set_seed
|
| 20 |
+
from diffusers import DDPMScheduler
|
| 21 |
+
from library import deepspeed_utils, model_util
|
| 22 |
+
|
| 23 |
+
import library.train_util as train_util
|
| 24 |
+
from library.train_util import DreamBoothDataset
|
| 25 |
+
import library.config_util as config_util
|
| 26 |
+
from library.config_util import (
|
| 27 |
+
ConfigSanitizer,
|
| 28 |
+
BlueprintGenerator,
|
| 29 |
+
)
|
| 30 |
+
import library.huggingface_util as huggingface_util
|
| 31 |
+
import library.custom_train_functions as custom_train_functions
|
| 32 |
+
from library.custom_train_functions import (
|
| 33 |
+
apply_snr_weight,
|
| 34 |
+
get_weighted_text_embeddings,
|
| 35 |
+
prepare_scheduler_for_custom_training,
|
| 36 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 37 |
+
add_v_prediction_like_loss,
|
| 38 |
+
apply_debiased_estimation,
|
| 39 |
+
apply_masked_loss,
|
| 40 |
+
)
|
| 41 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 42 |
+
|
| 43 |
+
setup_logging()
|
| 44 |
+
import logging
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class NetworkTrainer:
|
| 50 |
+
def __init__(self):
|
| 51 |
+
self.vae_scale_factor = 0.18215
|
| 52 |
+
self.is_sdxl = False
|
| 53 |
+
|
| 54 |
+
# TODO 他のスクリプトと共通化する
|
| 55 |
+
def generate_step_logs(
|
| 56 |
+
self,
|
| 57 |
+
args: argparse.Namespace,
|
| 58 |
+
current_loss,
|
| 59 |
+
avr_loss,
|
| 60 |
+
lr_scheduler,
|
| 61 |
+
lr_descriptions,
|
| 62 |
+
keys_scaled=None,
|
| 63 |
+
mean_norm=None,
|
| 64 |
+
maximum_norm=None,
|
| 65 |
+
):
|
| 66 |
+
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
| 67 |
+
|
| 68 |
+
if keys_scaled is not None:
|
| 69 |
+
logs["max_norm/keys_scaled"] = keys_scaled
|
| 70 |
+
logs["max_norm/average_key_norm"] = mean_norm
|
| 71 |
+
logs["max_norm/max_key_norm"] = maximum_norm
|
| 72 |
+
|
| 73 |
+
lrs = lr_scheduler.get_last_lr()
|
| 74 |
+
for i, lr in enumerate(lrs):
|
| 75 |
+
if lr_descriptions is not None:
|
| 76 |
+
lr_desc = lr_descriptions[i]
|
| 77 |
+
else:
|
| 78 |
+
idx = i - (0 if args.network_train_unet_only else -1)
|
| 79 |
+
if idx == -1:
|
| 80 |
+
lr_desc = "textencoder"
|
| 81 |
+
else:
|
| 82 |
+
if len(lrs) > 2:
|
| 83 |
+
lr_desc = f"group{idx}"
|
| 84 |
+
else:
|
| 85 |
+
lr_desc = "unet"
|
| 86 |
+
|
| 87 |
+
logs[f"lr/{lr_desc}"] = lr
|
| 88 |
+
|
| 89 |
+
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
| 90 |
+
# tracking d*lr value
|
| 91 |
+
logs[f"lr/d*lr/{lr_desc}"] = (
|
| 92 |
+
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
return logs
|
| 96 |
+
|
| 97 |
+
def assert_extra_args(self, args, train_dataset_group):
|
| 98 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
| 99 |
+
|
| 100 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
| 101 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 102 |
+
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
| 103 |
+
|
| 104 |
+
def load_tokenizer(self, args):
|
| 105 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 106 |
+
return tokenizer
|
| 107 |
+
|
| 108 |
+
def is_text_encoder_outputs_cached(self, args):
|
| 109 |
+
return False
|
| 110 |
+
|
| 111 |
+
def is_train_text_encoder(self, args):
|
| 112 |
+
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
|
| 113 |
+
|
| 114 |
+
def cache_text_encoder_outputs_if_needed(
|
| 115 |
+
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
| 116 |
+
):
|
| 117 |
+
for t_enc in text_encoders:
|
| 118 |
+
t_enc.to(accelerator.device, dtype=weight_dtype)
|
| 119 |
+
|
| 120 |
+
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
| 121 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
| 122 |
+
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype)
|
| 123 |
+
return encoder_hidden_states
|
| 124 |
+
|
| 125 |
+
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
| 126 |
+
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
| 127 |
+
return noise_pred
|
| 128 |
+
|
| 129 |
+
def all_reduce_network(self, accelerator, network):
|
| 130 |
+
for param in network.parameters():
|
| 131 |
+
if param.grad is not None:
|
| 132 |
+
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
| 133 |
+
|
| 134 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
| 135 |
+
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
| 136 |
+
|
| 137 |
+
def train(self, args):
|
| 138 |
+
session_id = random.randint(0, 2**32)
|
| 139 |
+
training_started_at = time.time()
|
| 140 |
+
train_util.verify_training_args(args)
|
| 141 |
+
train_util.prepare_dataset_args(args, True)
|
| 142 |
+
deepspeed_utils.prepare_deepspeed_args(args)
|
| 143 |
+
setup_logging(args, reset=True)
|
| 144 |
+
|
| 145 |
+
cache_latents = args.cache_latents
|
| 146 |
+
use_dreambooth_method = args.in_json is None
|
| 147 |
+
use_user_config = args.dataset_config is not None
|
| 148 |
+
|
| 149 |
+
if args.seed is None:
|
| 150 |
+
args.seed = random.randint(0, 2**32)
|
| 151 |
+
set_seed(args.seed)
|
| 152 |
+
|
| 153 |
+
# tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため
|
| 154 |
+
tokenizer = self.load_tokenizer(args)
|
| 155 |
+
tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
| 156 |
+
|
| 157 |
+
# データセットを準備する
|
| 158 |
+
if args.dataset_class is None:
|
| 159 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
| 160 |
+
if use_user_config:
|
| 161 |
+
logger.info(f"Loading dataset config from {args.dataset_config}")
|
| 162 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 163 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
| 164 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 165 |
+
logger.warning(
|
| 166 |
+
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 167 |
+
", ".join(ignored)
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
if use_dreambooth_method:
|
| 172 |
+
logger.info("Using DreamBooth method.")
|
| 173 |
+
user_config = {
|
| 174 |
+
"datasets": [
|
| 175 |
+
{
|
| 176 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
| 177 |
+
args.train_data_dir, args.reg_data_dir
|
| 178 |
+
)
|
| 179 |
+
}
|
| 180 |
+
]
|
| 181 |
+
}
|
| 182 |
+
else:
|
| 183 |
+
logger.info("Training with captions.")
|
| 184 |
+
user_config = {
|
| 185 |
+
"datasets": [
|
| 186 |
+
{
|
| 187 |
+
"subsets": [
|
| 188 |
+
{
|
| 189 |
+
"image_dir": args.train_data_dir,
|
| 190 |
+
"metadata_file": args.in_json,
|
| 191 |
+
}
|
| 192 |
+
]
|
| 193 |
+
}
|
| 194 |
+
]
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
| 198 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 199 |
+
else:
|
| 200 |
+
# use arbitrary dataset class
|
| 201 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
| 202 |
+
|
| 203 |
+
current_epoch = Value("i", 0)
|
| 204 |
+
current_step = Value("i", 0)
|
| 205 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 206 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 207 |
+
|
| 208 |
+
if args.debug_dataset:
|
| 209 |
+
train_util.debug_dataset(train_dataset_group)
|
| 210 |
+
return
|
| 211 |
+
if len(train_dataset_group) == 0:
|
| 212 |
+
logger.error(
|
| 213 |
+
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
| 214 |
+
)
|
| 215 |
+
return
|
| 216 |
+
|
| 217 |
+
if cache_latents:
|
| 218 |
+
assert (
|
| 219 |
+
train_dataset_group.is_latent_cacheable()
|
| 220 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 221 |
+
|
| 222 |
+
self.assert_extra_args(args, train_dataset_group)
|
| 223 |
+
|
| 224 |
+
# acceleratorを準備する
|
| 225 |
+
logger.info("preparing accelerator")
|
| 226 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 227 |
+
is_main_process = accelerator.is_main_process
|
| 228 |
+
|
| 229 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 230 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 231 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 232 |
+
|
| 233 |
+
# モデルを読み込む
|
| 234 |
+
model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
| 235 |
+
|
| 236 |
+
# text_encoder is List[CLIPTextModel] or CLIPTextModel
|
| 237 |
+
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
| 238 |
+
|
| 239 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 240 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 241 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
| 242 |
+
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
| 243 |
+
|
| 244 |
+
# 差分追加学習のためにモデルを読み込む
|
| 245 |
+
sys.path.append(os.path.dirname(__file__))
|
| 246 |
+
accelerator.print("import network module:", args.network_module)
|
| 247 |
+
network_module = importlib.import_module(args.network_module)
|
| 248 |
+
|
| 249 |
+
if args.base_weights is not None:
|
| 250 |
+
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
| 251 |
+
for i, weight_path in enumerate(args.base_weights):
|
| 252 |
+
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
| 253 |
+
multiplier = 1.0
|
| 254 |
+
else:
|
| 255 |
+
multiplier = args.base_weights_multiplier[i]
|
| 256 |
+
|
| 257 |
+
accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
|
| 258 |
+
|
| 259 |
+
module, weights_sd = network_module.create_network_from_weights(
|
| 260 |
+
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
| 261 |
+
)
|
| 262 |
+
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
| 263 |
+
|
| 264 |
+
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
|
| 265 |
+
|
| 266 |
+
# 学習を準備する
|
| 267 |
+
if cache_latents:
|
| 268 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 269 |
+
vae.requires_grad_(False)
|
| 270 |
+
vae.eval()
|
| 271 |
+
with torch.no_grad():
|
| 272 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
| 273 |
+
vae.to("cpu")
|
| 274 |
+
clean_memory_on_device(accelerator.device)
|
| 275 |
+
|
| 276 |
+
accelerator.wait_for_everyone()
|
| 277 |
+
|
| 278 |
+
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
| 279 |
+
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
| 280 |
+
self.cache_text_encoder_outputs_if_needed(
|
| 281 |
+
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# prepare network
|
| 285 |
+
net_kwargs = {}
|
| 286 |
+
if args.network_args is not None:
|
| 287 |
+
for net_arg in args.network_args:
|
| 288 |
+
key, value = net_arg.split("=")
|
| 289 |
+
net_kwargs[key] = value
|
| 290 |
+
|
| 291 |
+
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
| 292 |
+
if args.dim_from_weights:
|
| 293 |
+
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
| 294 |
+
else:
|
| 295 |
+
if "dropout" not in net_kwargs:
|
| 296 |
+
# workaround for LyCORIS (;^ω^)
|
| 297 |
+
net_kwargs["dropout"] = args.network_dropout
|
| 298 |
+
|
| 299 |
+
network = network_module.create_network(
|
| 300 |
+
1.0,
|
| 301 |
+
args.network_dim,
|
| 302 |
+
args.network_alpha,
|
| 303 |
+
vae,
|
| 304 |
+
text_encoder,
|
| 305 |
+
unet,
|
| 306 |
+
neuron_dropout=args.network_dropout,
|
| 307 |
+
**net_kwargs,
|
| 308 |
+
)
|
| 309 |
+
if network is None:
|
| 310 |
+
return
|
| 311 |
+
network_has_multiplier = hasattr(network, "set_multiplier")
|
| 312 |
+
|
| 313 |
+
if hasattr(network, "prepare_network"):
|
| 314 |
+
network.prepare_network(args)
|
| 315 |
+
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
| 316 |
+
logger.warning(
|
| 317 |
+
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
| 318 |
+
)
|
| 319 |
+
args.scale_weight_norms = False
|
| 320 |
+
|
| 321 |
+
train_unet = not args.network_train_text_encoder_only
|
| 322 |
+
train_text_encoder = self.is_train_text_encoder(args)
|
| 323 |
+
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
| 324 |
+
|
| 325 |
+
if args.network_weights is not None:
|
| 326 |
+
# FIXME consider alpha of weights
|
| 327 |
+
info = network.load_weights(args.network_weights)
|
| 328 |
+
accelerator.print(f"load network weights from {args.network_weights}: {info}")
|
| 329 |
+
|
| 330 |
+
if args.gradient_checkpointing:
|
| 331 |
+
unet.enable_gradient_checkpointing()
|
| 332 |
+
for t_enc in text_encoders:
|
| 333 |
+
t_enc.gradient_checkpointing_enable()
|
| 334 |
+
del t_enc
|
| 335 |
+
network.enable_gradient_checkpointing() # may have no effect
|
| 336 |
+
|
| 337 |
+
# 学習に必要なクラスを準備する
|
| 338 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 339 |
+
|
| 340 |
+
# 後方互換性を確保するよ
|
| 341 |
+
try:
|
| 342 |
+
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
| 343 |
+
if type(results) is tuple:
|
| 344 |
+
trainable_params = results[0]
|
| 345 |
+
lr_descriptions = results[1]
|
| 346 |
+
else:
|
| 347 |
+
trainable_params = results
|
| 348 |
+
lr_descriptions = None
|
| 349 |
+
except TypeError as e:
|
| 350 |
+
# logger.warning(f"{e}")
|
| 351 |
+
# accelerator.print(
|
| 352 |
+
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
| 353 |
+
# )
|
| 354 |
+
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
| 355 |
+
lr_descriptions = None
|
| 356 |
+
|
| 357 |
+
# if len(trainable_params) == 0:
|
| 358 |
+
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
|
| 359 |
+
# for params in trainable_params:
|
| 360 |
+
# for k, v in params.items():
|
| 361 |
+
# if type(v) == float:
|
| 362 |
+
# pass
|
| 363 |
+
# else:
|
| 364 |
+
# v = len(v)
|
| 365 |
+
# accelerator.print(f"trainable_params: {k} = {v}")
|
| 366 |
+
|
| 367 |
+
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 368 |
+
|
| 369 |
+
# dataloaderを準備する
|
| 370 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 371 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 372 |
+
|
| 373 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 374 |
+
train_dataset_group,
|
| 375 |
+
batch_size=1,
|
| 376 |
+
shuffle=True,
|
| 377 |
+
collate_fn=collator,
|
| 378 |
+
num_workers=n_workers,
|
| 379 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# 学習ステップ数を計算する
|
| 383 |
+
if args.max_train_epochs is not None:
|
| 384 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 385 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 386 |
+
)
|
| 387 |
+
accelerator.print(
|
| 388 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
# データセット側にも学習ステップを送信
|
| 392 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 393 |
+
|
| 394 |
+
# lr schedulerを用意する
|
| 395 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 396 |
+
|
| 397 |
+
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
| 398 |
+
if args.full_fp16:
|
| 399 |
+
assert (
|
| 400 |
+
args.mixed_precision == "fp16"
|
| 401 |
+
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
| 402 |
+
accelerator.print("enable full fp16 training.")
|
| 403 |
+
network.to(weight_dtype)
|
| 404 |
+
elif args.full_bf16:
|
| 405 |
+
assert (
|
| 406 |
+
args.mixed_precision == "bf16"
|
| 407 |
+
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
| 408 |
+
accelerator.print("enable full bf16 training.")
|
| 409 |
+
network.to(weight_dtype)
|
| 410 |
+
|
| 411 |
+
unet_weight_dtype = te_weight_dtype = weight_dtype
|
| 412 |
+
# Experimental Feature: Put base model into fp8 to save vram
|
| 413 |
+
if args.fp8_base:
|
| 414 |
+
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
| 415 |
+
assert (
|
| 416 |
+
args.mixed_precision != "no"
|
| 417 |
+
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
| 418 |
+
accelerator.print("enable fp8 training.")
|
| 419 |
+
unet_weight_dtype = torch.float8_e4m3fn
|
| 420 |
+
te_weight_dtype = torch.float8_e4m3fn
|
| 421 |
+
|
| 422 |
+
unet.requires_grad_(False)
|
| 423 |
+
unet.to(dtype=unet_weight_dtype)
|
| 424 |
+
for t_enc in text_encoders:
|
| 425 |
+
t_enc.requires_grad_(False)
|
| 426 |
+
|
| 427 |
+
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
| 428 |
+
if t_enc.device.type != "cpu":
|
| 429 |
+
t_enc.to(dtype=te_weight_dtype)
|
| 430 |
+
# nn.Embedding not support FP8
|
| 431 |
+
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
| 432 |
+
|
| 433 |
+
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
| 434 |
+
if args.deepspeed:
|
| 435 |
+
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
| 436 |
+
args,
|
| 437 |
+
unet=unet if train_unet else None,
|
| 438 |
+
text_encoder1=text_encoders[0] if train_text_encoder else None,
|
| 439 |
+
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
|
| 440 |
+
network=network,
|
| 441 |
+
)
|
| 442 |
+
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 443 |
+
ds_model, optimizer, train_dataloader, lr_scheduler
|
| 444 |
+
)
|
| 445 |
+
training_model = ds_model
|
| 446 |
+
else:
|
| 447 |
+
if train_unet:
|
| 448 |
+
unet = accelerator.prepare(unet)
|
| 449 |
+
else:
|
| 450 |
+
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
| 451 |
+
if train_text_encoder:
|
| 452 |
+
if len(text_encoders) > 1:
|
| 453 |
+
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
| 454 |
+
else:
|
| 455 |
+
text_encoder = accelerator.prepare(text_encoder)
|
| 456 |
+
text_encoders = [text_encoder]
|
| 457 |
+
else:
|
| 458 |
+
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
| 459 |
+
|
| 460 |
+
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 461 |
+
network, optimizer, train_dataloader, lr_scheduler
|
| 462 |
+
)
|
| 463 |
+
training_model = network
|
| 464 |
+
|
| 465 |
+
if args.gradient_checkpointing:
|
| 466 |
+
# according to TI example in Diffusers, train is required
|
| 467 |
+
unet.train()
|
| 468 |
+
for t_enc in text_encoders:
|
| 469 |
+
t_enc.train()
|
| 470 |
+
|
| 471 |
+
# set top parameter requires_grad = True for gradient checkpointing works
|
| 472 |
+
if train_text_encoder:
|
| 473 |
+
t_enc.text_model.embeddings.requires_grad_(True)
|
| 474 |
+
|
| 475 |
+
else:
|
| 476 |
+
unet.eval()
|
| 477 |
+
for t_enc in text_encoders:
|
| 478 |
+
t_enc.eval()
|
| 479 |
+
|
| 480 |
+
del t_enc
|
| 481 |
+
|
| 482 |
+
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
| 483 |
+
|
| 484 |
+
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
| 485 |
+
vae.requires_grad_(False)
|
| 486 |
+
vae.eval()
|
| 487 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 488 |
+
|
| 489 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 490 |
+
if args.full_fp16:
|
| 491 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 492 |
+
|
| 493 |
+
# before resuming make hook for saving/loading to save/load the network weights only
|
| 494 |
+
def save_model_hook(models, weights, output_dir):
|
| 495 |
+
# pop weights of other models than network to save only network weights
|
| 496 |
+
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
|
| 497 |
+
if accelerator.is_main_process or args.deepspeed:
|
| 498 |
+
remove_indices = []
|
| 499 |
+
for i, model in enumerate(models):
|
| 500 |
+
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
| 501 |
+
remove_indices.append(i)
|
| 502 |
+
for i in reversed(remove_indices):
|
| 503 |
+
if len(weights) > i:
|
| 504 |
+
weights.pop(i)
|
| 505 |
+
# print(f"save model hook: {len(weights)} weights will be saved")
|
| 506 |
+
|
| 507 |
+
# save current ecpoch and step
|
| 508 |
+
train_state_file = os.path.join(output_dir, "train_state.json")
|
| 509 |
+
# +1 is needed because the state is saved before current_step is set from global_step
|
| 510 |
+
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
|
| 511 |
+
with open(train_state_file, "w", encoding="utf-8") as f:
|
| 512 |
+
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
|
| 513 |
+
|
| 514 |
+
steps_from_state = None
|
| 515 |
+
|
| 516 |
+
def load_model_hook(models, input_dir):
|
| 517 |
+
# remove models except network
|
| 518 |
+
remove_indices = []
|
| 519 |
+
for i, model in enumerate(models):
|
| 520 |
+
if not isinstance(model, type(accelerator.unwrap_model(network))):
|
| 521 |
+
remove_indices.append(i)
|
| 522 |
+
for i in reversed(remove_indices):
|
| 523 |
+
models.pop(i)
|
| 524 |
+
# print(f"load model hook: {len(models)} models will be loaded")
|
| 525 |
+
|
| 526 |
+
# load current epoch and step to
|
| 527 |
+
nonlocal steps_from_state
|
| 528 |
+
train_state_file = os.path.join(input_dir, "train_state.json")
|
| 529 |
+
if os.path.exists(train_state_file):
|
| 530 |
+
with open(train_state_file, "r", encoding="utf-8") as f:
|
| 531 |
+
data = json.load(f)
|
| 532 |
+
steps_from_state = data["current_step"]
|
| 533 |
+
logger.info(f"load train state from {train_state_file}: {data}")
|
| 534 |
+
|
| 535 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
| 536 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
| 537 |
+
|
| 538 |
+
# resumeする
|
| 539 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 540 |
+
|
| 541 |
+
# epoch数を計算する
|
| 542 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 543 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 544 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 545 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 546 |
+
|
| 547 |
+
# 学習する
|
| 548 |
+
# TODO: find a way to handle total batch size when there are multiple datasets
|
| 549 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 550 |
+
|
| 551 |
+
accelerator.print("running training / 学習開始")
|
| 552 |
+
accelerator.print(f" num train images * repeats / 学���画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 553 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 554 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 555 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 556 |
+
accelerator.print(
|
| 557 |
+
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
| 558 |
+
)
|
| 559 |
+
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
| 560 |
+
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 561 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 562 |
+
|
| 563 |
+
# TODO refactor metadata creation and move to util
|
| 564 |
+
metadata = {
|
| 565 |
+
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
| 566 |
+
"ss_training_started_at": training_started_at, # unix timestamp
|
| 567 |
+
"ss_output_name": args.output_name,
|
| 568 |
+
"ss_learning_rate": args.learning_rate,
|
| 569 |
+
"ss_text_encoder_lr": args.text_encoder_lr,
|
| 570 |
+
"ss_unet_lr": args.unet_lr,
|
| 571 |
+
"ss_num_train_images": train_dataset_group.num_train_images,
|
| 572 |
+
"ss_num_reg_images": train_dataset_group.num_reg_images,
|
| 573 |
+
"ss_num_batches_per_epoch": len(train_dataloader),
|
| 574 |
+
"ss_num_epochs": num_train_epochs,
|
| 575 |
+
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
| 576 |
+
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
| 577 |
+
"ss_max_train_steps": args.max_train_steps,
|
| 578 |
+
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
| 579 |
+
"ss_lr_scheduler": args.lr_scheduler,
|
| 580 |
+
"ss_network_module": args.network_module,
|
| 581 |
+
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
| 582 |
+
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
|
| 583 |
+
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
|
| 584 |
+
"ss_mixed_precision": args.mixed_precision,
|
| 585 |
+
"ss_full_fp16": bool(args.full_fp16),
|
| 586 |
+
"ss_v2": bool(args.v2),
|
| 587 |
+
"ss_base_model_version": model_version,
|
| 588 |
+
"ss_clip_skip": args.clip_skip,
|
| 589 |
+
"ss_max_token_length": args.max_token_length,
|
| 590 |
+
"ss_cache_latents": bool(args.cache_latents),
|
| 591 |
+
"ss_seed": args.seed,
|
| 592 |
+
"ss_lowram": args.lowram,
|
| 593 |
+
"ss_noise_offset": args.noise_offset,
|
| 594 |
+
"ss_multires_noise_iterations": args.multires_noise_iterations,
|
| 595 |
+
"ss_multires_noise_discount": args.multires_noise_discount,
|
| 596 |
+
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
|
| 597 |
+
"ss_zero_terminal_snr": args.zero_terminal_snr,
|
| 598 |
+
"ss_training_comment": args.training_comment, # will not be updated after training
|
| 599 |
+
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
| 600 |
+
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
| 601 |
+
"ss_max_grad_norm": args.max_grad_norm,
|
| 602 |
+
"ss_caption_dropout_rate": args.caption_dropout_rate,
|
| 603 |
+
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
|
| 604 |
+
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
|
| 605 |
+
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
| 606 |
+
"ss_prior_loss_weight": args.prior_loss_weight,
|
| 607 |
+
"ss_min_snr_gamma": args.min_snr_gamma,
|
| 608 |
+
"ss_scale_weight_norms": args.scale_weight_norms,
|
| 609 |
+
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
| 610 |
+
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
|
| 611 |
+
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
|
| 612 |
+
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
|
| 613 |
+
"ss_loss_type": args.loss_type,
|
| 614 |
+
"ss_huber_schedule": args.huber_schedule,
|
| 615 |
+
"ss_huber_c": args.huber_c,
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
if use_user_config:
|
| 619 |
+
# save metadata of multiple datasets
|
| 620 |
+
# NOTE: pack "ss_datasets" value as json one time
|
| 621 |
+
# or should also pack nested collections as json?
|
| 622 |
+
datasets_metadata = []
|
| 623 |
+
tag_frequency = {} # merge tag frequency for metadata editor
|
| 624 |
+
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
| 625 |
+
|
| 626 |
+
for dataset in train_dataset_group.datasets:
|
| 627 |
+
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
| 628 |
+
dataset_metadata = {
|
| 629 |
+
"is_dreambooth": is_dreambooth_dataset,
|
| 630 |
+
"batch_size_per_device": dataset.batch_size,
|
| 631 |
+
"num_train_images": dataset.num_train_images, # includes repeating
|
| 632 |
+
"num_reg_images": dataset.num_reg_images,
|
| 633 |
+
"resolution": (dataset.width, dataset.height),
|
| 634 |
+
"enable_bucket": bool(dataset.enable_bucket),
|
| 635 |
+
"min_bucket_reso": dataset.min_bucket_reso,
|
| 636 |
+
"max_bucket_reso": dataset.max_bucket_reso,
|
| 637 |
+
"tag_frequency": dataset.tag_frequency,
|
| 638 |
+
"bucket_info": dataset.bucket_info,
|
| 639 |
+
}
|
| 640 |
+
|
| 641 |
+
subsets_metadata = []
|
| 642 |
+
for subset in dataset.subsets:
|
| 643 |
+
subset_metadata = {
|
| 644 |
+
"img_count": subset.img_count,
|
| 645 |
+
"num_repeats": subset.num_repeats,
|
| 646 |
+
"color_aug": bool(subset.color_aug),
|
| 647 |
+
"flip_aug": bool(subset.flip_aug),
|
| 648 |
+
"random_crop": bool(subset.random_crop),
|
| 649 |
+
"shuffle_caption": bool(subset.shuffle_caption),
|
| 650 |
+
"keep_tokens": subset.keep_tokens,
|
| 651 |
+
"keep_tokens_separator": subset.keep_tokens_separator,
|
| 652 |
+
"secondary_separator": subset.secondary_separator,
|
| 653 |
+
"enable_wildcard": bool(subset.enable_wildcard),
|
| 654 |
+
"caption_prefix": subset.caption_prefix,
|
| 655 |
+
"caption_suffix": subset.caption_suffix,
|
| 656 |
+
}
|
| 657 |
+
|
| 658 |
+
image_dir_or_metadata_file = None
|
| 659 |
+
if subset.image_dir:
|
| 660 |
+
image_dir = os.path.basename(subset.image_dir)
|
| 661 |
+
subset_metadata["image_dir"] = image_dir
|
| 662 |
+
image_dir_or_metadata_file = image_dir
|
| 663 |
+
|
| 664 |
+
if is_dreambooth_dataset:
|
| 665 |
+
subset_metadata["class_tokens"] = subset.class_tokens
|
| 666 |
+
subset_metadata["is_reg"] = subset.is_reg
|
| 667 |
+
if subset.is_reg:
|
| 668 |
+
image_dir_or_metadata_file = None # not merging reg dataset
|
| 669 |
+
else:
|
| 670 |
+
metadata_file = os.path.basename(subset.metadata_file)
|
| 671 |
+
subset_metadata["metadata_file"] = metadata_file
|
| 672 |
+
image_dir_or_metadata_file = metadata_file # may overwrite
|
| 673 |
+
|
| 674 |
+
subsets_metadata.append(subset_metadata)
|
| 675 |
+
|
| 676 |
+
# merge dataset dir: not reg subset only
|
| 677 |
+
# TODO update additional-network extension to show detailed dataset config from metadata
|
| 678 |
+
if image_dir_or_metadata_file is not None:
|
| 679 |
+
# datasets may have a certain dir multiple times
|
| 680 |
+
v = image_dir_or_metadata_file
|
| 681 |
+
i = 2
|
| 682 |
+
while v in dataset_dirs_info:
|
| 683 |
+
v = image_dir_or_metadata_file + f" ({i})"
|
| 684 |
+
i += 1
|
| 685 |
+
image_dir_or_metadata_file = v
|
| 686 |
+
|
| 687 |
+
dataset_dirs_info[image_dir_or_metadata_file] = {
|
| 688 |
+
"n_repeats": subset.num_repeats,
|
| 689 |
+
"img_count": subset.img_count,
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
dataset_metadata["subsets"] = subsets_metadata
|
| 693 |
+
datasets_metadata.append(dataset_metadata)
|
| 694 |
+
|
| 695 |
+
# merge tag frequency:
|
| 696 |
+
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
|
| 697 |
+
# あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える
|
| 698 |
+
# もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない
|
| 699 |
+
# なので、ここで複数datasetの回数を合算してもあまり意味はない
|
| 700 |
+
if ds_dir_name in tag_frequency:
|
| 701 |
+
continue
|
| 702 |
+
tag_frequency[ds_dir_name] = ds_freq_for_dir
|
| 703 |
+
|
| 704 |
+
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
| 705 |
+
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
| 706 |
+
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
| 707 |
+
else:
|
| 708 |
+
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
| 709 |
+
assert (
|
| 710 |
+
len(train_dataset_group.datasets) == 1
|
| 711 |
+
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
| 712 |
+
|
| 713 |
+
dataset = train_dataset_group.datasets[0]
|
| 714 |
+
|
| 715 |
+
dataset_dirs_info = {}
|
| 716 |
+
reg_dataset_dirs_info = {}
|
| 717 |
+
if use_dreambooth_method:
|
| 718 |
+
for subset in dataset.subsets:
|
| 719 |
+
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
|
| 720 |
+
info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
|
| 721 |
+
else:
|
| 722 |
+
for subset in dataset.subsets:
|
| 723 |
+
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
|
| 724 |
+
"n_repeats": subset.num_repeats,
|
| 725 |
+
"img_count": subset.img_count,
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
metadata.update(
|
| 729 |
+
{
|
| 730 |
+
"ss_batch_size_per_device": args.train_batch_size,
|
| 731 |
+
"ss_total_batch_size": total_batch_size,
|
| 732 |
+
"ss_resolution": args.resolution,
|
| 733 |
+
"ss_color_aug": bool(args.color_aug),
|
| 734 |
+
"ss_flip_aug": bool(args.flip_aug),
|
| 735 |
+
"ss_random_crop": bool(args.random_crop),
|
| 736 |
+
"ss_shuffle_caption": bool(args.shuffle_caption),
|
| 737 |
+
"ss_enable_bucket": bool(dataset.enable_bucket),
|
| 738 |
+
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
|
| 739 |
+
"ss_min_bucket_reso": dataset.min_bucket_reso,
|
| 740 |
+
"ss_max_bucket_reso": dataset.max_bucket_reso,
|
| 741 |
+
"ss_keep_tokens": args.keep_tokens,
|
| 742 |
+
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
|
| 743 |
+
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
|
| 744 |
+
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
|
| 745 |
+
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
| 746 |
+
}
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
# add extra args
|
| 750 |
+
if args.network_args:
|
| 751 |
+
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
| 752 |
+
|
| 753 |
+
# model name and hash
|
| 754 |
+
if args.pretrained_model_name_or_path is not None:
|
| 755 |
+
sd_model_name = args.pretrained_model_name_or_path
|
| 756 |
+
if os.path.exists(sd_model_name):
|
| 757 |
+
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
| 758 |
+
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
| 759 |
+
sd_model_name = os.path.basename(sd_model_name)
|
| 760 |
+
metadata["ss_sd_model_name"] = sd_model_name
|
| 761 |
+
|
| 762 |
+
if args.vae is not None:
|
| 763 |
+
vae_name = args.vae
|
| 764 |
+
if os.path.exists(vae_name):
|
| 765 |
+
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
| 766 |
+
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
| 767 |
+
vae_name = os.path.basename(vae_name)
|
| 768 |
+
metadata["ss_vae_name"] = vae_name
|
| 769 |
+
|
| 770 |
+
metadata = {k: str(v) for k, v in metadata.items()}
|
| 771 |
+
|
| 772 |
+
# make minimum metadata for filtering
|
| 773 |
+
minimum_metadata = {}
|
| 774 |
+
for key in train_util.SS_METADATA_MINIMUM_KEYS:
|
| 775 |
+
if key in metadata:
|
| 776 |
+
minimum_metadata[key] = metadata[key]
|
| 777 |
+
|
| 778 |
+
# calculate steps to skip when resuming or starting from a specific step
|
| 779 |
+
initial_step = 0
|
| 780 |
+
if args.initial_epoch is not None or args.initial_step is not None:
|
| 781 |
+
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
|
| 782 |
+
if steps_from_state is not None:
|
| 783 |
+
logger.warning(
|
| 784 |
+
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
|
| 785 |
+
)
|
| 786 |
+
if args.initial_step is not None:
|
| 787 |
+
initial_step = args.initial_step
|
| 788 |
+
else:
|
| 789 |
+
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
|
| 790 |
+
initial_step = (args.initial_epoch - 1) * math.ceil(
|
| 791 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 792 |
+
)
|
| 793 |
+
else:
|
| 794 |
+
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
|
| 795 |
+
if steps_from_state is not None:
|
| 796 |
+
initial_step = steps_from_state
|
| 797 |
+
steps_from_state = None
|
| 798 |
+
|
| 799 |
+
if initial_step > 0:
|
| 800 |
+
assert (
|
| 801 |
+
args.max_train_steps > initial_step
|
| 802 |
+
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
| 803 |
+
|
| 804 |
+
progress_bar = tqdm(
|
| 805 |
+
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
epoch_to_start = 0
|
| 809 |
+
if initial_step > 0:
|
| 810 |
+
if args.skip_until_initial_step:
|
| 811 |
+
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
|
| 812 |
+
if not args.resume:
|
| 813 |
+
logger.info(
|
| 814 |
+
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr scheduler���最初から始まります"
|
| 815 |
+
)
|
| 816 |
+
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
|
| 817 |
+
initial_step *= args.gradient_accumulation_steps
|
| 818 |
+
|
| 819 |
+
# set epoch to start to make initial_step less than len(train_dataloader)
|
| 820 |
+
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 821 |
+
else:
|
| 822 |
+
# if not, only epoch no is skipped for informative purpose
|
| 823 |
+
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 824 |
+
initial_step = 0 # do not skip
|
| 825 |
+
|
| 826 |
+
global_step = 0
|
| 827 |
+
|
| 828 |
+
noise_scheduler = DDPMScheduler(
|
| 829 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 830 |
+
)
|
| 831 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 832 |
+
if args.zero_terminal_snr:
|
| 833 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 834 |
+
|
| 835 |
+
if accelerator.is_main_process:
|
| 836 |
+
init_kwargs = {}
|
| 837 |
+
if args.wandb_run_name:
|
| 838 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 839 |
+
if args.log_tracker_config is not None:
|
| 840 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 841 |
+
accelerator.init_trackers(
|
| 842 |
+
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
| 843 |
+
config=train_util.get_sanitized_config_or_none(args),
|
| 844 |
+
init_kwargs=init_kwargs,
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
loss_recorder = train_util.LossRecorder()
|
| 848 |
+
del train_dataset_group
|
| 849 |
+
|
| 850 |
+
# callback for step start
|
| 851 |
+
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
| 852 |
+
on_step_start = accelerator.unwrap_model(network).on_step_start
|
| 853 |
+
else:
|
| 854 |
+
on_step_start = lambda *args, **kwargs: None
|
| 855 |
+
|
| 856 |
+
# function for saving/removing
|
| 857 |
+
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
| 858 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 859 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
| 860 |
+
|
| 861 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
| 862 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
| 863 |
+
metadata["ss_steps"] = str(steps)
|
| 864 |
+
metadata["ss_epoch"] = str(epoch_no)
|
| 865 |
+
|
| 866 |
+
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
| 867 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
| 868 |
+
metadata_to_save.update(sai_metadata)
|
| 869 |
+
|
| 870 |
+
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
| 871 |
+
if args.huggingface_repo_id is not None:
|
| 872 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
| 873 |
+
|
| 874 |
+
def remove_model(old_ckpt_name):
|
| 875 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
| 876 |
+
if os.path.exists(old_ckpt_file):
|
| 877 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
| 878 |
+
os.remove(old_ckpt_file)
|
| 879 |
+
|
| 880 |
+
# For --sample_at_first
|
| 881 |
+
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 882 |
+
|
| 883 |
+
# training loop
|
| 884 |
+
if initial_step > 0: # only if skip_until_initial_step is specified
|
| 885 |
+
for skip_epoch in range(epoch_to_start): # skip epochs
|
| 886 |
+
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
|
| 887 |
+
initial_step -= len(train_dataloader)
|
| 888 |
+
global_step = initial_step
|
| 889 |
+
|
| 890 |
+
for epoch in range(epoch_to_start, num_train_epochs):
|
| 891 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 892 |
+
current_epoch.value = epoch + 1
|
| 893 |
+
|
| 894 |
+
metadata["ss_epoch"] = str(epoch + 1)
|
| 895 |
+
|
| 896 |
+
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
| 897 |
+
|
| 898 |
+
skipped_dataloader = None
|
| 899 |
+
if initial_step > 0:
|
| 900 |
+
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
| 901 |
+
initial_step = 1
|
| 902 |
+
|
| 903 |
+
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
| 904 |
+
current_step.value = global_step
|
| 905 |
+
if initial_step > 0:
|
| 906 |
+
initial_step -= 1
|
| 907 |
+
continue
|
| 908 |
+
|
| 909 |
+
with accelerator.accumulate(training_model):
|
| 910 |
+
on_step_start(text_encoder, unet)
|
| 911 |
+
|
| 912 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 913 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 914 |
+
else:
|
| 915 |
+
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
| 916 |
+
with torch.no_grad():
|
| 917 |
+
# latentに変換
|
| 918 |
+
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
| 919 |
+
else:
|
| 920 |
+
chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)]
|
| 921 |
+
list_latents = []
|
| 922 |
+
for chunk in chunks:
|
| 923 |
+
with torch.no_grad():
|
| 924 |
+
# latentに変換
|
| 925 |
+
list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype))
|
| 926 |
+
latents = torch.cat(list_latents, dim=0)
|
| 927 |
+
# NaNが含まれていれば警告を表示し0に置き換える
|
| 928 |
+
if torch.any(torch.isnan(latents)):
|
| 929 |
+
accelerator.print("NaN found in latents, replacing with zeros")
|
| 930 |
+
latents = torch.nan_to_num(latents, 0, out=latents)
|
| 931 |
+
latents = latents * self.vae_scale_factor
|
| 932 |
+
|
| 933 |
+
# get multiplier for each sample
|
| 934 |
+
if network_has_multiplier:
|
| 935 |
+
multipliers = batch["network_multipliers"]
|
| 936 |
+
# if all multipliers are same, use single multiplier
|
| 937 |
+
if torch.all(multipliers == multipliers[0]):
|
| 938 |
+
multipliers = multipliers[0].item()
|
| 939 |
+
else:
|
| 940 |
+
raise NotImplementedError("multipliers for each sample is not supported yet")
|
| 941 |
+
# print(f"set multiplier: {multipliers}")
|
| 942 |
+
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
| 943 |
+
|
| 944 |
+
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
| 945 |
+
# Get the text embedding for conditioning
|
| 946 |
+
if args.weighted_captions:
|
| 947 |
+
text_encoder_conds = get_weighted_text_embeddings(
|
| 948 |
+
tokenizer,
|
| 949 |
+
text_encoder,
|
| 950 |
+
batch["captions"],
|
| 951 |
+
accelerator.device,
|
| 952 |
+
args.max_token_length // 75 if args.max_token_length else 1,
|
| 953 |
+
clip_skip=args.clip_skip,
|
| 954 |
+
)
|
| 955 |
+
else:
|
| 956 |
+
text_encoder_conds = self.get_text_cond(
|
| 957 |
+
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 961 |
+
# with noise offset and/or multires noise if specified
|
| 962 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
| 963 |
+
args, noise_scheduler, latents
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
# ensure the hidden state will require grad
|
| 967 |
+
if args.gradient_checkpointing:
|
| 968 |
+
for x in noisy_latents:
|
| 969 |
+
x.requires_grad_(True)
|
| 970 |
+
for t in text_encoder_conds:
|
| 971 |
+
t.requires_grad_(True)
|
| 972 |
+
|
| 973 |
+
# Predict the noise residual
|
| 974 |
+
with accelerator.autocast():
|
| 975 |
+
noise_pred = self.call_unet(
|
| 976 |
+
args,
|
| 977 |
+
accelerator,
|
| 978 |
+
unet,
|
| 979 |
+
noisy_latents.requires_grad_(train_unet),
|
| 980 |
+
timesteps,
|
| 981 |
+
text_encoder_conds,
|
| 982 |
+
batch,
|
| 983 |
+
weight_dtype,
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
if args.v_parameterization:
|
| 987 |
+
# v-parameterization training
|
| 988 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 989 |
+
else:
|
| 990 |
+
target = noise
|
| 991 |
+
|
| 992 |
+
loss = train_util.conditional_loss(
|
| 993 |
+
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
| 994 |
+
)
|
| 995 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
| 996 |
+
loss = apply_masked_loss(loss, batch)
|
| 997 |
+
loss = loss.mean([1, 2, 3])
|
| 998 |
+
|
| 999 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 1000 |
+
loss = loss * loss_weights
|
| 1001 |
+
|
| 1002 |
+
if args.min_snr_gamma:
|
| 1003 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 1004 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 1005 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 1006 |
+
if args.v_pred_like_loss:
|
| 1007 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
| 1008 |
+
if args.debiased_estimation_loss:
|
| 1009 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 1010 |
+
|
| 1011 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 1012 |
+
|
| 1013 |
+
accelerator.backward(loss)
|
| 1014 |
+
if accelerator.sync_gradients:
|
| 1015 |
+
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
| 1016 |
+
if args.max_grad_norm != 0.0:
|
| 1017 |
+
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
| 1018 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 1019 |
+
|
| 1020 |
+
optimizer.step()
|
| 1021 |
+
lr_scheduler.step()
|
| 1022 |
+
optimizer.zero_grad(set_to_none=True)
|
| 1023 |
+
|
| 1024 |
+
if args.scale_weight_norms:
|
| 1025 |
+
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
| 1026 |
+
args.scale_weight_norms, accelerator.device
|
| 1027 |
+
)
|
| 1028 |
+
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
| 1029 |
+
else:
|
| 1030 |
+
keys_scaled, mean_norm, maximum_norm = None, None, None
|
| 1031 |
+
|
| 1032 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 1033 |
+
if accelerator.sync_gradients:
|
| 1034 |
+
progress_bar.update(1)
|
| 1035 |
+
global_step += 1
|
| 1036 |
+
|
| 1037 |
+
self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 1038 |
+
|
| 1039 |
+
# 指定ステップごとにモデルを保存
|
| 1040 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 1041 |
+
accelerator.wait_for_everyone()
|
| 1042 |
+
if accelerator.is_main_process:
|
| 1043 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
| 1044 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
|
| 1045 |
+
|
| 1046 |
+
if args.save_state:
|
| 1047 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
| 1048 |
+
|
| 1049 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
| 1050 |
+
if remove_step_no is not None:
|
| 1051 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
| 1052 |
+
remove_model(remove_ckpt_name)
|
| 1053 |
+
|
| 1054 |
+
current_loss = loss.detach().item()
|
| 1055 |
+
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
| 1056 |
+
avr_loss: float = loss_recorder.moving_average
|
| 1057 |
+
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 1058 |
+
progress_bar.set_postfix(**logs)
|
| 1059 |
+
|
| 1060 |
+
if args.scale_weight_norms:
|
| 1061 |
+
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
| 1062 |
+
|
| 1063 |
+
if args.logging_dir is not None:
|
| 1064 |
+
logs = self.generate_step_logs(
|
| 1065 |
+
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
|
| 1066 |
+
)
|
| 1067 |
+
accelerator.log(logs, step=global_step)
|
| 1068 |
+
|
| 1069 |
+
if global_step >= args.max_train_steps:
|
| 1070 |
+
break
|
| 1071 |
+
|
| 1072 |
+
if args.logging_dir is not None:
|
| 1073 |
+
logs = {"loss/epoch": loss_recorder.moving_average}
|
| 1074 |
+
accelerator.log(logs, step=epoch + 1)
|
| 1075 |
+
|
| 1076 |
+
accelerator.wait_for_everyone()
|
| 1077 |
+
|
| 1078 |
+
# 指定エポックごとにモデルを保存
|
| 1079 |
+
if args.save_every_n_epochs is not None:
|
| 1080 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
| 1081 |
+
if is_main_process and saving:
|
| 1082 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
| 1083 |
+
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
|
| 1084 |
+
|
| 1085 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
| 1086 |
+
if remove_epoch_no is not None:
|
| 1087 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
| 1088 |
+
remove_model(remove_ckpt_name)
|
| 1089 |
+
|
| 1090 |
+
if args.save_state:
|
| 1091 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
| 1092 |
+
|
| 1093 |
+
self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
| 1094 |
+
|
| 1095 |
+
# end of epoch
|
| 1096 |
+
|
| 1097 |
+
# metadata["ss_epoch"] = str(num_train_epochs)
|
| 1098 |
+
metadata["ss_training_finished_at"] = str(time.time())
|
| 1099 |
+
|
| 1100 |
+
if is_main_process:
|
| 1101 |
+
network = accelerator.unwrap_model(network)
|
| 1102 |
+
|
| 1103 |
+
accelerator.end_training()
|
| 1104 |
+
|
| 1105 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 1106 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 1107 |
+
|
| 1108 |
+
if is_main_process:
|
| 1109 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
| 1110 |
+
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
| 1111 |
+
|
| 1112 |
+
logger.info("model saved.")
|
| 1113 |
+
|
| 1114 |
+
|
| 1115 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 1116 |
+
parser = argparse.ArgumentParser()
|
| 1117 |
+
|
| 1118 |
+
add_logging_arguments(parser)
|
| 1119 |
+
train_util.add_sd_models_arguments(parser)
|
| 1120 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
| 1121 |
+
train_util.add_training_arguments(parser, True)
|
| 1122 |
+
train_util.add_masked_loss_arguments(parser)
|
| 1123 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 1124 |
+
train_util.add_optimizer_arguments(parser)
|
| 1125 |
+
config_util.add_config_arguments(parser)
|
| 1126 |
+
custom_train_functions.add_custom_train_arguments(parser)
|
| 1127 |
+
|
| 1128 |
+
parser.add_argument(
|
| 1129 |
+
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
|
| 1130 |
+
)
|
| 1131 |
+
parser.add_argument(
|
| 1132 |
+
"--save_model_as",
|
| 1133 |
+
type=str,
|
| 1134 |
+
default="safetensors",
|
| 1135 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
| 1136 |
+
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
| 1137 |
+
)
|
| 1138 |
+
|
| 1139 |
+
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
| 1140 |
+
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
| 1141 |
+
|
| 1142 |
+
parser.add_argument(
|
| 1143 |
+
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
| 1144 |
+
)
|
| 1145 |
+
parser.add_argument(
|
| 1146 |
+
"--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
|
| 1147 |
+
)
|
| 1148 |
+
parser.add_argument(
|
| 1149 |
+
"--network_dim",
|
| 1150 |
+
type=int,
|
| 1151 |
+
default=None,
|
| 1152 |
+
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
|
| 1153 |
+
)
|
| 1154 |
+
parser.add_argument(
|
| 1155 |
+
"--network_alpha",
|
| 1156 |
+
type=float,
|
| 1157 |
+
default=1,
|
| 1158 |
+
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
|
| 1159 |
+
)
|
| 1160 |
+
parser.add_argument(
|
| 1161 |
+
"--network_dropout",
|
| 1162 |
+
type=float,
|
| 1163 |
+
default=None,
|
| 1164 |
+
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
| 1165 |
+
)
|
| 1166 |
+
parser.add_argument(
|
| 1167 |
+
"--network_args",
|
| 1168 |
+
type=str,
|
| 1169 |
+
default=None,
|
| 1170 |
+
nargs="*",
|
| 1171 |
+
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
| 1172 |
+
)
|
| 1173 |
+
parser.add_argument(
|
| 1174 |
+
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
|
| 1175 |
+
)
|
| 1176 |
+
parser.add_argument(
|
| 1177 |
+
"--network_train_text_encoder_only",
|
| 1178 |
+
action="store_true",
|
| 1179 |
+
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
|
| 1180 |
+
)
|
| 1181 |
+
parser.add_argument(
|
| 1182 |
+
"--training_comment",
|
| 1183 |
+
type=str,
|
| 1184 |
+
default=None,
|
| 1185 |
+
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
|
| 1186 |
+
)
|
| 1187 |
+
parser.add_argument(
|
| 1188 |
+
"--dim_from_weights",
|
| 1189 |
+
action="store_true",
|
| 1190 |
+
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
| 1191 |
+
)
|
| 1192 |
+
parser.add_argument(
|
| 1193 |
+
"--scale_weight_norms",
|
| 1194 |
+
type=float,
|
| 1195 |
+
default=None,
|
| 1196 |
+
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
|
| 1197 |
+
)
|
| 1198 |
+
parser.add_argument(
|
| 1199 |
+
"--base_weights",
|
| 1200 |
+
type=str,
|
| 1201 |
+
default=None,
|
| 1202 |
+
nargs="*",
|
| 1203 |
+
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
|
| 1204 |
+
)
|
| 1205 |
+
parser.add_argument(
|
| 1206 |
+
"--base_weights_multiplier",
|
| 1207 |
+
type=float,
|
| 1208 |
+
default=None,
|
| 1209 |
+
nargs="*",
|
| 1210 |
+
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
| 1211 |
+
)
|
| 1212 |
+
parser.add_argument(
|
| 1213 |
+
"--no_half_vae",
|
| 1214 |
+
action="store_true",
|
| 1215 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 1216 |
+
)
|
| 1217 |
+
parser.add_argument(
|
| 1218 |
+
"--skip_until_initial_step",
|
| 1219 |
+
action="store_true",
|
| 1220 |
+
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
|
| 1221 |
+
)
|
| 1222 |
+
parser.add_argument(
|
| 1223 |
+
"--initial_epoch",
|
| 1224 |
+
type=int,
|
| 1225 |
+
default=None,
|
| 1226 |
+
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
|
| 1227 |
+
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
|
| 1228 |
+
)
|
| 1229 |
+
parser.add_argument(
|
| 1230 |
+
"--initial_step",
|
| 1231 |
+
type=int,
|
| 1232 |
+
default=None,
|
| 1233 |
+
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
|
| 1234 |
+
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
|
| 1235 |
+
)
|
| 1236 |
+
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
|
| 1237 |
+
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
|
| 1238 |
+
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
|
| 1239 |
+
return parser
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
+
if __name__ == "__main__":
|
| 1243 |
+
parser = setup_parser()
|
| 1244 |
+
|
| 1245 |
+
args = parser.parse_args()
|
| 1246 |
+
train_util.verify_command_line_training_args(args)
|
| 1247 |
+
args = train_util.read_config_from_file(args, parser)
|
| 1248 |
+
|
| 1249 |
+
trainer = NetworkTrainer()
|
| 1250 |
+
trainer.train(args)
|
train_textual_inversion.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from multiprocessing import Value
|
| 5 |
+
import toml
|
| 6 |
+
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
init_ipex()
|
| 14 |
+
|
| 15 |
+
from accelerate.utils import set_seed
|
| 16 |
+
from diffusers import DDPMScheduler
|
| 17 |
+
from transformers import CLIPTokenizer
|
| 18 |
+
from library import deepspeed_utils, model_util
|
| 19 |
+
|
| 20 |
+
import library.train_util as train_util
|
| 21 |
+
import library.huggingface_util as huggingface_util
|
| 22 |
+
import library.config_util as config_util
|
| 23 |
+
from library.config_util import (
|
| 24 |
+
ConfigSanitizer,
|
| 25 |
+
BlueprintGenerator,
|
| 26 |
+
)
|
| 27 |
+
import library.custom_train_functions as custom_train_functions
|
| 28 |
+
from library.custom_train_functions import (
|
| 29 |
+
apply_snr_weight,
|
| 30 |
+
prepare_scheduler_for_custom_training,
|
| 31 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 32 |
+
add_v_prediction_like_loss,
|
| 33 |
+
apply_debiased_estimation,
|
| 34 |
+
apply_masked_loss,
|
| 35 |
+
)
|
| 36 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 37 |
+
|
| 38 |
+
setup_logging()
|
| 39 |
+
import logging
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
imagenet_templates_small = [
|
| 44 |
+
"a photo of a {}",
|
| 45 |
+
"a rendering of a {}",
|
| 46 |
+
"a cropped photo of the {}",
|
| 47 |
+
"the photo of a {}",
|
| 48 |
+
"a photo of a clean {}",
|
| 49 |
+
"a photo of a dirty {}",
|
| 50 |
+
"a dark photo of the {}",
|
| 51 |
+
"a photo of my {}",
|
| 52 |
+
"a photo of the cool {}",
|
| 53 |
+
"a close-up photo of a {}",
|
| 54 |
+
"a bright photo of the {}",
|
| 55 |
+
"a cropped photo of a {}",
|
| 56 |
+
"a photo of the {}",
|
| 57 |
+
"a good photo of the {}",
|
| 58 |
+
"a photo of one {}",
|
| 59 |
+
"a close-up photo of the {}",
|
| 60 |
+
"a rendition of the {}",
|
| 61 |
+
"a photo of the clean {}",
|
| 62 |
+
"a rendition of a {}",
|
| 63 |
+
"a photo of a nice {}",
|
| 64 |
+
"a good photo of a {}",
|
| 65 |
+
"a photo of the nice {}",
|
| 66 |
+
"a photo of the small {}",
|
| 67 |
+
"a photo of the weird {}",
|
| 68 |
+
"a photo of the large {}",
|
| 69 |
+
"a photo of a cool {}",
|
| 70 |
+
"a photo of a small {}",
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
imagenet_style_templates_small = [
|
| 74 |
+
"a painting in the style of {}",
|
| 75 |
+
"a rendering in the style of {}",
|
| 76 |
+
"a cropped painting in the style of {}",
|
| 77 |
+
"the painting in the style of {}",
|
| 78 |
+
"a clean painting in the style of {}",
|
| 79 |
+
"a dirty painting in the style of {}",
|
| 80 |
+
"a dark painting in the style of {}",
|
| 81 |
+
"a picture in the style of {}",
|
| 82 |
+
"a cool painting in the style of {}",
|
| 83 |
+
"a close-up painting in the style of {}",
|
| 84 |
+
"a bright painting in the style of {}",
|
| 85 |
+
"a cropped painting in the style of {}",
|
| 86 |
+
"a good painting in the style of {}",
|
| 87 |
+
"a close-up painting in the style of {}",
|
| 88 |
+
"a rendition in the style of {}",
|
| 89 |
+
"a nice painting in the style of {}",
|
| 90 |
+
"a small painting in the style of {}",
|
| 91 |
+
"a weird painting in the style of {}",
|
| 92 |
+
"a large painting in the style of {}",
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class TextualInversionTrainer:
|
| 97 |
+
def __init__(self):
|
| 98 |
+
self.vae_scale_factor = 0.18215
|
| 99 |
+
self.is_sdxl = False
|
| 100 |
+
|
| 101 |
+
def assert_extra_args(self, args, train_dataset_group):
|
| 102 |
+
train_dataset_group.verify_bucket_reso_steps(64)
|
| 103 |
+
|
| 104 |
+
def load_target_model(self, args, weight_dtype, accelerator):
|
| 105 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 106 |
+
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
|
| 107 |
+
|
| 108 |
+
def load_tokenizer(self, args):
|
| 109 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 110 |
+
return tokenizer
|
| 111 |
+
|
| 112 |
+
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
|
| 113 |
+
pass
|
| 114 |
+
|
| 115 |
+
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
| 116 |
+
with torch.enable_grad():
|
| 117 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
| 118 |
+
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
|
| 119 |
+
return encoder_hidden_states
|
| 120 |
+
|
| 121 |
+
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
|
| 122 |
+
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
| 123 |
+
return noise_pred
|
| 124 |
+
|
| 125 |
+
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
|
| 126 |
+
train_util.sample_images(
|
| 127 |
+
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def save_weights(self, file, updated_embs, save_dtype, metadata):
|
| 131 |
+
state_dict = {"emb_params": updated_embs[0]}
|
| 132 |
+
|
| 133 |
+
if save_dtype is not None:
|
| 134 |
+
for key in list(state_dict.keys()):
|
| 135 |
+
v = state_dict[key]
|
| 136 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
| 137 |
+
state_dict[key] = v
|
| 138 |
+
|
| 139 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 140 |
+
from safetensors.torch import save_file
|
| 141 |
+
|
| 142 |
+
save_file(state_dict, file, metadata)
|
| 143 |
+
else:
|
| 144 |
+
torch.save(state_dict, file) # can be loaded in Web UI
|
| 145 |
+
|
| 146 |
+
def load_weights(self, file):
|
| 147 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 148 |
+
from safetensors.torch import load_file
|
| 149 |
+
|
| 150 |
+
data = load_file(file)
|
| 151 |
+
else:
|
| 152 |
+
# compatible to Web UI's file format
|
| 153 |
+
data = torch.load(file, map_location="cpu")
|
| 154 |
+
if type(data) != dict:
|
| 155 |
+
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
|
| 156 |
+
|
| 157 |
+
if "string_to_param" in data: # textual inversion embeddings
|
| 158 |
+
data = data["string_to_param"]
|
| 159 |
+
if hasattr(data, "_parameters"): # support old PyTorch?
|
| 160 |
+
data = getattr(data, "_parameters")
|
| 161 |
+
|
| 162 |
+
emb = next(iter(data.values()))
|
| 163 |
+
if type(emb) != torch.Tensor:
|
| 164 |
+
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
|
| 165 |
+
|
| 166 |
+
if len(emb.size()) == 1:
|
| 167 |
+
emb = emb.unsqueeze(0)
|
| 168 |
+
|
| 169 |
+
return [emb]
|
| 170 |
+
|
| 171 |
+
def train(self, args):
|
| 172 |
+
if args.output_name is None:
|
| 173 |
+
args.output_name = args.token_string
|
| 174 |
+
use_template = args.use_object_template or args.use_style_template
|
| 175 |
+
|
| 176 |
+
train_util.verify_training_args(args)
|
| 177 |
+
train_util.prepare_dataset_args(args, True)
|
| 178 |
+
setup_logging(args, reset=True)
|
| 179 |
+
|
| 180 |
+
cache_latents = args.cache_latents
|
| 181 |
+
|
| 182 |
+
if args.seed is not None:
|
| 183 |
+
set_seed(args.seed)
|
| 184 |
+
|
| 185 |
+
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
|
| 186 |
+
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
| 187 |
+
|
| 188 |
+
# acceleratorを準備する
|
| 189 |
+
logger.info("prepare accelerator")
|
| 190 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 191 |
+
|
| 192 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 193 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 194 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 195 |
+
|
| 196 |
+
# モデルを読み込む
|
| 197 |
+
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
| 198 |
+
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
|
| 199 |
+
|
| 200 |
+
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
|
| 201 |
+
accelerator.print(
|
| 202 |
+
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
|
| 203 |
+
+ "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Convert the init_word to token_id
|
| 207 |
+
init_token_ids_list = []
|
| 208 |
+
if args.init_word is not None:
|
| 209 |
+
for i, tokenizer in enumerate(tokenizers):
|
| 210 |
+
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
| 211 |
+
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
| 212 |
+
accelerator.print(
|
| 213 |
+
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / "
|
| 214 |
+
+ f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
|
| 215 |
+
)
|
| 216 |
+
init_token_ids_list.append(init_token_ids)
|
| 217 |
+
else:
|
| 218 |
+
init_token_ids_list = [None] * len(tokenizers)
|
| 219 |
+
|
| 220 |
+
# tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
|
| 221 |
+
# token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
|
| 222 |
+
# add new word to tokenizer, count is num_vectors_per_token
|
| 223 |
+
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
|
| 224 |
+
|
| 225 |
+
self.assert_token_string(args.token_string, tokenizers)
|
| 226 |
+
|
| 227 |
+
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
| 228 |
+
token_ids_list = []
|
| 229 |
+
token_embeds_list = []
|
| 230 |
+
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
| 231 |
+
num_added_tokens = tokenizer.add_tokens(token_strings)
|
| 232 |
+
assert (
|
| 233 |
+
num_added_tokens == args.num_vectors_per_token
|
| 234 |
+
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}"
|
| 235 |
+
|
| 236 |
+
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
| 237 |
+
accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}")
|
| 238 |
+
assert (
|
| 239 |
+
min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
|
| 240 |
+
), f"token ids is not ordered : tokenizer {i+1}, {token_ids}"
|
| 241 |
+
assert (
|
| 242 |
+
len(tokenizer) - 1 == token_ids[-1]
|
| 243 |
+
), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
|
| 244 |
+
token_ids_list.append(token_ids)
|
| 245 |
+
|
| 246 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
| 247 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
| 248 |
+
|
| 249 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
| 250 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
| 251 |
+
if init_token_ids is not None:
|
| 252 |
+
for i, token_id in enumerate(token_ids):
|
| 253 |
+
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
| 254 |
+
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
| 255 |
+
token_embeds_list.append(token_embeds)
|
| 256 |
+
|
| 257 |
+
# load weights
|
| 258 |
+
if args.weights is not None:
|
| 259 |
+
embeddings_list = self.load_weights(args.weights)
|
| 260 |
+
assert len(token_ids) == len(
|
| 261 |
+
embeddings_list[0]
|
| 262 |
+
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
| 263 |
+
# accelerator.print(token_ids, embeddings.size())
|
| 264 |
+
for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list):
|
| 265 |
+
for token_id, embedding in zip(token_ids, embeddings):
|
| 266 |
+
token_embeds[token_id] = embedding
|
| 267 |
+
# accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
| 268 |
+
accelerator.print(f"weighs loaded")
|
| 269 |
+
|
| 270 |
+
accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
| 271 |
+
|
| 272 |
+
# データセットを準備する
|
| 273 |
+
if args.dataset_class is None:
|
| 274 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
|
| 275 |
+
if args.dataset_config is not None:
|
| 276 |
+
accelerator.print(f"Load dataset config from {args.dataset_config}")
|
| 277 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 278 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
| 279 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 280 |
+
accelerator.print(
|
| 281 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 282 |
+
", ".join(ignored)
|
| 283 |
+
)
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
use_dreambooth_method = args.in_json is None
|
| 287 |
+
if use_dreambooth_method:
|
| 288 |
+
accelerator.print("Use DreamBooth method.")
|
| 289 |
+
user_config = {
|
| 290 |
+
"datasets": [
|
| 291 |
+
{
|
| 292 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
| 293 |
+
args.train_data_dir, args.reg_data_dir
|
| 294 |
+
)
|
| 295 |
+
}
|
| 296 |
+
]
|
| 297 |
+
}
|
| 298 |
+
else:
|
| 299 |
+
logger.info("Train with captions.")
|
| 300 |
+
user_config = {
|
| 301 |
+
"datasets": [
|
| 302 |
+
{
|
| 303 |
+
"subsets": [
|
| 304 |
+
{
|
| 305 |
+
"image_dir": args.train_data_dir,
|
| 306 |
+
"metadata_file": args.in_json,
|
| 307 |
+
}
|
| 308 |
+
]
|
| 309 |
+
}
|
| 310 |
+
]
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
|
| 314 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 315 |
+
else:
|
| 316 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
|
| 317 |
+
|
| 318 |
+
self.assert_extra_args(args, train_dataset_group)
|
| 319 |
+
|
| 320 |
+
current_epoch = Value("i", 0)
|
| 321 |
+
current_step = Value("i", 0)
|
| 322 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 323 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 324 |
+
|
| 325 |
+
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
| 326 |
+
if use_template:
|
| 327 |
+
accelerator.print(f"use template for training captions. is object: {args.use_object_template}")
|
| 328 |
+
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
| 329 |
+
replace_to = " ".join(token_strings)
|
| 330 |
+
captions = []
|
| 331 |
+
for tmpl in templates:
|
| 332 |
+
captions.append(tmpl.format(replace_to))
|
| 333 |
+
train_dataset_group.add_replacement("", captions)
|
| 334 |
+
|
| 335 |
+
# サンプル生成用
|
| 336 |
+
if args.num_vectors_per_token > 1:
|
| 337 |
+
prompt_replacement = (args.token_string, replace_to)
|
| 338 |
+
else:
|
| 339 |
+
prompt_replacement = None
|
| 340 |
+
else:
|
| 341 |
+
# サンプル生成用
|
| 342 |
+
if args.num_vectors_per_token > 1:
|
| 343 |
+
replace_to = " ".join(token_strings)
|
| 344 |
+
train_dataset_group.add_replacement(args.token_string, replace_to)
|
| 345 |
+
prompt_replacement = (args.token_string, replace_to)
|
| 346 |
+
else:
|
| 347 |
+
prompt_replacement = None
|
| 348 |
+
|
| 349 |
+
if args.debug_dataset:
|
| 350 |
+
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
| 351 |
+
return
|
| 352 |
+
if len(train_dataset_group) == 0:
|
| 353 |
+
accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
| 354 |
+
return
|
| 355 |
+
|
| 356 |
+
if cache_latents:
|
| 357 |
+
assert (
|
| 358 |
+
train_dataset_group.is_latent_cacheable()
|
| 359 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 360 |
+
|
| 361 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 362 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 363 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
| 364 |
+
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
| 365 |
+
|
| 366 |
+
# 学習を準備する
|
| 367 |
+
if cache_latents:
|
| 368 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 369 |
+
vae.requires_grad_(False)
|
| 370 |
+
vae.eval()
|
| 371 |
+
with torch.no_grad():
|
| 372 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
| 373 |
+
vae.to("cpu")
|
| 374 |
+
clean_memory_on_device(accelerator.device)
|
| 375 |
+
|
| 376 |
+
accelerator.wait_for_everyone()
|
| 377 |
+
|
| 378 |
+
if args.gradient_checkpointing:
|
| 379 |
+
unet.enable_gradient_checkpointing()
|
| 380 |
+
for text_encoder in text_encoders:
|
| 381 |
+
text_encoder.gradient_checkpointing_enable()
|
| 382 |
+
|
| 383 |
+
# 学習に必要なクラスを準備する
|
| 384 |
+
accelerator.print("prepare optimizer, data loader etc.")
|
| 385 |
+
trainable_params = []
|
| 386 |
+
for text_encoder in text_encoders:
|
| 387 |
+
trainable_params += text_encoder.get_input_embeddings().parameters()
|
| 388 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 389 |
+
|
| 390 |
+
# dataloaderを準備する
|
| 391 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 392 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 393 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 394 |
+
train_dataset_group,
|
| 395 |
+
batch_size=1,
|
| 396 |
+
shuffle=True,
|
| 397 |
+
collate_fn=collator,
|
| 398 |
+
num_workers=n_workers,
|
| 399 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# 学習ステップ数を計算する
|
| 403 |
+
if args.max_train_epochs is not None:
|
| 404 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 405 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 406 |
+
)
|
| 407 |
+
accelerator.print(
|
| 408 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# データセット側にも学習ステップを送信
|
| 412 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 413 |
+
|
| 414 |
+
# lr schedulerを用意する
|
| 415 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 416 |
+
|
| 417 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 418 |
+
if len(text_encoders) == 1:
|
| 419 |
+
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 420 |
+
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
elif len(text_encoders) == 2:
|
| 424 |
+
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 425 |
+
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
| 429 |
+
|
| 430 |
+
else:
|
| 431 |
+
raise NotImplementedError()
|
| 432 |
+
|
| 433 |
+
index_no_updates_list = []
|
| 434 |
+
orig_embeds_params_list = []
|
| 435 |
+
for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
|
| 436 |
+
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
| 437 |
+
index_no_updates_list.append(index_no_updates)
|
| 438 |
+
|
| 439 |
+
# accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
|
| 440 |
+
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
| 441 |
+
orig_embeds_params_list.append(orig_embeds_params)
|
| 442 |
+
|
| 443 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
| 444 |
+
text_encoder.requires_grad_(True)
|
| 445 |
+
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
|
| 446 |
+
unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
|
| 447 |
+
unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
| 448 |
+
unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
| 449 |
+
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
| 450 |
+
|
| 451 |
+
unet.requires_grad_(False)
|
| 452 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 453 |
+
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
| 454 |
+
# TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
|
| 455 |
+
unet.train()
|
| 456 |
+
else:
|
| 457 |
+
unet.eval()
|
| 458 |
+
|
| 459 |
+
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
| 460 |
+
vae.requires_grad_(False)
|
| 461 |
+
vae.eval()
|
| 462 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 463 |
+
|
| 464 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 465 |
+
if args.full_fp16:
|
| 466 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 467 |
+
for text_encoder in text_encoders:
|
| 468 |
+
text_encoder.to(weight_dtype)
|
| 469 |
+
if args.full_bf16:
|
| 470 |
+
for text_encoder in text_encoders:
|
| 471 |
+
text_encoder.to(weight_dtype)
|
| 472 |
+
|
| 473 |
+
# resumeする
|
| 474 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 475 |
+
|
| 476 |
+
# epoch数を計算する
|
| 477 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 478 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 479 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 480 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 481 |
+
|
| 482 |
+
# 学習する
|
| 483 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 484 |
+
accelerator.print("running training / 学習開始")
|
| 485 |
+
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 486 |
+
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 487 |
+
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 488 |
+
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
| 489 |
+
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
| 490 |
+
accelerator.print(
|
| 491 |
+
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
| 492 |
+
)
|
| 493 |
+
accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 494 |
+
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 495 |
+
|
| 496 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 497 |
+
global_step = 0
|
| 498 |
+
|
| 499 |
+
noise_scheduler = DDPMScheduler(
|
| 500 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 501 |
+
)
|
| 502 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 503 |
+
if args.zero_terminal_snr:
|
| 504 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 505 |
+
|
| 506 |
+
if accelerator.is_main_process:
|
| 507 |
+
init_kwargs = {}
|
| 508 |
+
if args.wandb_run_name:
|
| 509 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 510 |
+
if args.log_tracker_config is not None:
|
| 511 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 512 |
+
accelerator.init_trackers(
|
| 513 |
+
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
# function for saving/removing
|
| 517 |
+
def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
|
| 518 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 519 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
| 520 |
+
|
| 521 |
+
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
| 522 |
+
|
| 523 |
+
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, False, True)
|
| 524 |
+
|
| 525 |
+
self.save_weights(ckpt_file, embs_list, save_dtype, sai_metadata)
|
| 526 |
+
if args.huggingface_repo_id is not None:
|
| 527 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
| 528 |
+
|
| 529 |
+
def remove_model(old_ckpt_name):
|
| 530 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
| 531 |
+
if os.path.exists(old_ckpt_file):
|
| 532 |
+
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
| 533 |
+
os.remove(old_ckpt_file)
|
| 534 |
+
|
| 535 |
+
# For --sample_at_first
|
| 536 |
+
self.sample_images(
|
| 537 |
+
accelerator,
|
| 538 |
+
args,
|
| 539 |
+
0,
|
| 540 |
+
global_step,
|
| 541 |
+
accelerator.device,
|
| 542 |
+
vae,
|
| 543 |
+
tokenizer_or_list,
|
| 544 |
+
text_encoder_or_list,
|
| 545 |
+
unet,
|
| 546 |
+
prompt_replacement,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# training loop
|
| 550 |
+
for epoch in range(num_train_epochs):
|
| 551 |
+
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
| 552 |
+
current_epoch.value = epoch + 1
|
| 553 |
+
|
| 554 |
+
for text_encoder in text_encoders:
|
| 555 |
+
text_encoder.train()
|
| 556 |
+
|
| 557 |
+
loss_total = 0
|
| 558 |
+
|
| 559 |
+
for step, batch in enumerate(train_dataloader):
|
| 560 |
+
current_step.value = global_step
|
| 561 |
+
with accelerator.accumulate(text_encoders[0]):
|
| 562 |
+
with torch.no_grad():
|
| 563 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 564 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 565 |
+
else:
|
| 566 |
+
# latentに変換
|
| 567 |
+
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
| 568 |
+
latents = latents * self.vae_scale_factor
|
| 569 |
+
|
| 570 |
+
# Get the text embedding for conditioning
|
| 571 |
+
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
|
| 572 |
+
|
| 573 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 574 |
+
# with noise offset and/or multires noise if specified
|
| 575 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
|
| 576 |
+
args, noise_scheduler, latents
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# Predict the noise residual
|
| 580 |
+
with accelerator.autocast():
|
| 581 |
+
noise_pred = self.call_unet(
|
| 582 |
+
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
if args.v_parameterization:
|
| 586 |
+
# v-parameterization training
|
| 587 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 588 |
+
else:
|
| 589 |
+
target = noise
|
| 590 |
+
|
| 591 |
+
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
| 592 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
| 593 |
+
loss = apply_masked_loss(loss, batch)
|
| 594 |
+
loss = loss.mean([1, 2, 3])
|
| 595 |
+
|
| 596 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 597 |
+
loss = loss * loss_weights
|
| 598 |
+
|
| 599 |
+
if args.min_snr_gamma:
|
| 600 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 601 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 602 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 603 |
+
if args.v_pred_like_loss:
|
| 604 |
+
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
| 605 |
+
if args.debiased_estimation_loss:
|
| 606 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 607 |
+
|
| 608 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 609 |
+
|
| 610 |
+
accelerator.backward(loss)
|
| 611 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 612 |
+
params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
|
| 613 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 614 |
+
|
| 615 |
+
optimizer.step()
|
| 616 |
+
lr_scheduler.step()
|
| 617 |
+
optimizer.zero_grad(set_to_none=True)
|
| 618 |
+
|
| 619 |
+
# Let's make sure we don't update any embedding weights besides the newly added token
|
| 620 |
+
with torch.no_grad():
|
| 621 |
+
for text_encoder, orig_embeds_params, index_no_updates in zip(
|
| 622 |
+
text_encoders, orig_embeds_params_list, index_no_updates_list
|
| 623 |
+
):
|
| 624 |
+
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
|
| 625 |
+
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
|
| 626 |
+
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
|
| 627 |
+
index_no_updates
|
| 628 |
+
]
|
| 629 |
+
|
| 630 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 631 |
+
if accelerator.sync_gradients:
|
| 632 |
+
progress_bar.update(1)
|
| 633 |
+
global_step += 1
|
| 634 |
+
|
| 635 |
+
self.sample_images(
|
| 636 |
+
accelerator,
|
| 637 |
+
args,
|
| 638 |
+
None,
|
| 639 |
+
global_step,
|
| 640 |
+
accelerator.device,
|
| 641 |
+
vae,
|
| 642 |
+
tokenizer_or_list,
|
| 643 |
+
text_encoder_or_list,
|
| 644 |
+
unet,
|
| 645 |
+
prompt_replacement,
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
# 指定ステップごとにモデルを保存
|
| 649 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 650 |
+
accelerator.wait_for_everyone()
|
| 651 |
+
if accelerator.is_main_process:
|
| 652 |
+
updated_embs_list = []
|
| 653 |
+
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
|
| 654 |
+
updated_embs = (
|
| 655 |
+
accelerator.unwrap_model(text_encoder)
|
| 656 |
+
.get_input_embeddings()
|
| 657 |
+
.weight[token_ids]
|
| 658 |
+
.data.detach()
|
| 659 |
+
.clone()
|
| 660 |
+
)
|
| 661 |
+
updated_embs_list.append(updated_embs)
|
| 662 |
+
|
| 663 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
| 664 |
+
save_model(ckpt_name, updated_embs_list, global_step, epoch)
|
| 665 |
+
|
| 666 |
+
if args.save_state:
|
| 667 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
| 668 |
+
|
| 669 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
| 670 |
+
if remove_step_no is not None:
|
| 671 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
| 672 |
+
remove_model(remove_ckpt_name)
|
| 673 |
+
|
| 674 |
+
current_loss = loss.detach().item()
|
| 675 |
+
if args.logging_dir is not None:
|
| 676 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
| 677 |
+
if (
|
| 678 |
+
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
| 679 |
+
): # tracking d*lr value
|
| 680 |
+
logs["lr/d*lr"] = (
|
| 681 |
+
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
| 682 |
+
)
|
| 683 |
+
accelerator.log(logs, step=global_step)
|
| 684 |
+
|
| 685 |
+
loss_total += current_loss
|
| 686 |
+
avr_loss = loss_total / (step + 1)
|
| 687 |
+
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 688 |
+
progress_bar.set_postfix(**logs)
|
| 689 |
+
|
| 690 |
+
if global_step >= args.max_train_steps:
|
| 691 |
+
break
|
| 692 |
+
|
| 693 |
+
if args.logging_dir is not None:
|
| 694 |
+
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
| 695 |
+
accelerator.log(logs, step=epoch + 1)
|
| 696 |
+
|
| 697 |
+
accelerator.wait_for_everyone()
|
| 698 |
+
|
| 699 |
+
updated_embs_list = []
|
| 700 |
+
for text_encoder, token_ids in zip(text_encoders, token_ids_list):
|
| 701 |
+
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
| 702 |
+
updated_embs_list.append(updated_embs)
|
| 703 |
+
|
| 704 |
+
if args.save_every_n_epochs is not None:
|
| 705 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
| 706 |
+
if accelerator.is_main_process and saving:
|
| 707 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
| 708 |
+
save_model(ckpt_name, updated_embs_list, epoch + 1, global_step)
|
| 709 |
+
|
| 710 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
| 711 |
+
if remove_epoch_no is not None:
|
| 712 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
| 713 |
+
remove_model(remove_ckpt_name)
|
| 714 |
+
|
| 715 |
+
if args.save_state:
|
| 716 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
| 717 |
+
|
| 718 |
+
self.sample_images(
|
| 719 |
+
accelerator,
|
| 720 |
+
args,
|
| 721 |
+
epoch + 1,
|
| 722 |
+
global_step,
|
| 723 |
+
accelerator.device,
|
| 724 |
+
vae,
|
| 725 |
+
tokenizer_or_list,
|
| 726 |
+
text_encoder_or_list,
|
| 727 |
+
unet,
|
| 728 |
+
prompt_replacement,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
# end of epoch
|
| 732 |
+
|
| 733 |
+
is_main_process = accelerator.is_main_process
|
| 734 |
+
if is_main_process:
|
| 735 |
+
text_encoder = accelerator.unwrap_model(text_encoder)
|
| 736 |
+
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
| 737 |
+
|
| 738 |
+
accelerator.end_training()
|
| 739 |
+
|
| 740 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 741 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 742 |
+
|
| 743 |
+
if is_main_process:
|
| 744 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
| 745 |
+
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
| 746 |
+
|
| 747 |
+
logger.info("model saved.")
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 751 |
+
parser = argparse.ArgumentParser()
|
| 752 |
+
|
| 753 |
+
add_logging_arguments(parser)
|
| 754 |
+
train_util.add_sd_models_arguments(parser)
|
| 755 |
+
train_util.add_dataset_arguments(parser, True, True, False)
|
| 756 |
+
train_util.add_training_arguments(parser, True)
|
| 757 |
+
train_util.add_masked_loss_arguments(parser)
|
| 758 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 759 |
+
train_util.add_optimizer_arguments(parser)
|
| 760 |
+
config_util.add_config_arguments(parser)
|
| 761 |
+
custom_train_functions.add_custom_train_arguments(parser, False)
|
| 762 |
+
|
| 763 |
+
parser.add_argument(
|
| 764 |
+
"--save_model_as",
|
| 765 |
+
type=str,
|
| 766 |
+
default="pt",
|
| 767 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
| 768 |
+
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
parser.add_argument(
|
| 772 |
+
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
|
| 773 |
+
)
|
| 774 |
+
parser.add_argument(
|
| 775 |
+
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
|
| 776 |
+
)
|
| 777 |
+
parser.add_argument(
|
| 778 |
+
"--token_string",
|
| 779 |
+
type=str,
|
| 780 |
+
default=None,
|
| 781 |
+
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
| 782 |
+
)
|
| 783 |
+
parser.add_argument(
|
| 784 |
+
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
|
| 785 |
+
)
|
| 786 |
+
parser.add_argument(
|
| 787 |
+
"--use_object_template",
|
| 788 |
+
action="store_true",
|
| 789 |
+
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
|
| 790 |
+
)
|
| 791 |
+
parser.add_argument(
|
| 792 |
+
"--use_style_template",
|
| 793 |
+
action="store_true",
|
| 794 |
+
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
| 795 |
+
)
|
| 796 |
+
parser.add_argument(
|
| 797 |
+
"--no_half_vae",
|
| 798 |
+
action="store_true",
|
| 799 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
return parser
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
if __name__ == "__main__":
|
| 806 |
+
parser = setup_parser()
|
| 807 |
+
|
| 808 |
+
args = parser.parse_args()
|
| 809 |
+
train_util.verify_command_line_training_args(args)
|
| 810 |
+
args = train_util.read_config_from_file(args, parser)
|
| 811 |
+
|
| 812 |
+
trainer = TextualInversionTrainer()
|
| 813 |
+
trainer.train(args)
|
train_textual_inversion_XTI.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
import toml
|
| 6 |
+
from multiprocessing import Value
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from library import deepspeed_utils
|
| 12 |
+
from library.device_utils import init_ipex, clean_memory_on_device
|
| 13 |
+
|
| 14 |
+
init_ipex()
|
| 15 |
+
|
| 16 |
+
from accelerate.utils import set_seed
|
| 17 |
+
import diffusers
|
| 18 |
+
from diffusers import DDPMScheduler
|
| 19 |
+
import library
|
| 20 |
+
|
| 21 |
+
import library.train_util as train_util
|
| 22 |
+
import library.huggingface_util as huggingface_util
|
| 23 |
+
import library.config_util as config_util
|
| 24 |
+
from library.config_util import (
|
| 25 |
+
ConfigSanitizer,
|
| 26 |
+
BlueprintGenerator,
|
| 27 |
+
)
|
| 28 |
+
import library.custom_train_functions as custom_train_functions
|
| 29 |
+
from library.custom_train_functions import (
|
| 30 |
+
apply_snr_weight,
|
| 31 |
+
prepare_scheduler_for_custom_training,
|
| 32 |
+
pyramid_noise_like,
|
| 33 |
+
apply_noise_offset,
|
| 34 |
+
scale_v_prediction_loss_like_noise_prediction,
|
| 35 |
+
apply_debiased_estimation,
|
| 36 |
+
apply_masked_loss,
|
| 37 |
+
)
|
| 38 |
+
import library.original_unet as original_unet
|
| 39 |
+
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
| 40 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 41 |
+
|
| 42 |
+
setup_logging()
|
| 43 |
+
import logging
|
| 44 |
+
|
| 45 |
+
logger = logging.getLogger(__name__)
|
| 46 |
+
|
| 47 |
+
imagenet_templates_small = [
|
| 48 |
+
"a photo of a {}",
|
| 49 |
+
"a rendering of a {}",
|
| 50 |
+
"a cropped photo of the {}",
|
| 51 |
+
"the photo of a {}",
|
| 52 |
+
"a photo of a clean {}",
|
| 53 |
+
"a photo of a dirty {}",
|
| 54 |
+
"a dark photo of the {}",
|
| 55 |
+
"a photo of my {}",
|
| 56 |
+
"a photo of the cool {}",
|
| 57 |
+
"a close-up photo of a {}",
|
| 58 |
+
"a bright photo of the {}",
|
| 59 |
+
"a cropped photo of a {}",
|
| 60 |
+
"a photo of the {}",
|
| 61 |
+
"a good photo of the {}",
|
| 62 |
+
"a photo of one {}",
|
| 63 |
+
"a close-up photo of the {}",
|
| 64 |
+
"a rendition of the {}",
|
| 65 |
+
"a photo of the clean {}",
|
| 66 |
+
"a rendition of a {}",
|
| 67 |
+
"a photo of a nice {}",
|
| 68 |
+
"a good photo of a {}",
|
| 69 |
+
"a photo of the nice {}",
|
| 70 |
+
"a photo of the small {}",
|
| 71 |
+
"a photo of the weird {}",
|
| 72 |
+
"a photo of the large {}",
|
| 73 |
+
"a photo of a cool {}",
|
| 74 |
+
"a photo of a small {}",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
imagenet_style_templates_small = [
|
| 78 |
+
"a painting in the style of {}",
|
| 79 |
+
"a rendering in the style of {}",
|
| 80 |
+
"a cropped painting in the style of {}",
|
| 81 |
+
"the painting in the style of {}",
|
| 82 |
+
"a clean painting in the style of {}",
|
| 83 |
+
"a dirty painting in the style of {}",
|
| 84 |
+
"a dark painting in the style of {}",
|
| 85 |
+
"a picture in the style of {}",
|
| 86 |
+
"a cool painting in the style of {}",
|
| 87 |
+
"a close-up painting in the style of {}",
|
| 88 |
+
"a bright painting in the style of {}",
|
| 89 |
+
"a cropped painting in the style of {}",
|
| 90 |
+
"a good painting in the style of {}",
|
| 91 |
+
"a close-up painting in the style of {}",
|
| 92 |
+
"a rendition in the style of {}",
|
| 93 |
+
"a nice painting in the style of {}",
|
| 94 |
+
"a small painting in the style of {}",
|
| 95 |
+
"a weird painting in the style of {}",
|
| 96 |
+
"a large painting in the style of {}",
|
| 97 |
+
]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def train(args):
|
| 101 |
+
if args.output_name is None:
|
| 102 |
+
args.output_name = args.token_string
|
| 103 |
+
use_template = args.use_object_template or args.use_style_template
|
| 104 |
+
setup_logging(args, reset=True)
|
| 105 |
+
|
| 106 |
+
train_util.verify_training_args(args)
|
| 107 |
+
train_util.prepare_dataset_args(args, True)
|
| 108 |
+
|
| 109 |
+
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
| 110 |
+
logger.warning(
|
| 111 |
+
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
| 112 |
+
)
|
| 113 |
+
assert (
|
| 114 |
+
args.dataset_class is None
|
| 115 |
+
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
|
| 116 |
+
|
| 117 |
+
cache_latents = args.cache_latents
|
| 118 |
+
|
| 119 |
+
if args.seed is not None:
|
| 120 |
+
set_seed(args.seed)
|
| 121 |
+
|
| 122 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 123 |
+
|
| 124 |
+
# acceleratorを準備する
|
| 125 |
+
logger.info("prepare accelerator")
|
| 126 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 127 |
+
|
| 128 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 129 |
+
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
| 130 |
+
|
| 131 |
+
# モデルを読み込む
|
| 132 |
+
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 133 |
+
|
| 134 |
+
# Convert the init_word to token_id
|
| 135 |
+
if args.init_word is not None:
|
| 136 |
+
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
| 137 |
+
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
| 138 |
+
logger.warning(
|
| 139 |
+
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
| 140 |
+
)
|
| 141 |
+
else:
|
| 142 |
+
init_token_ids = None
|
| 143 |
+
|
| 144 |
+
# add new word to tokenizer, count is num_vectors_per_token
|
| 145 |
+
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
| 146 |
+
num_added_tokens = tokenizer.add_tokens(token_strings)
|
| 147 |
+
assert (
|
| 148 |
+
num_added_tokens == args.num_vectors_per_token
|
| 149 |
+
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
| 150 |
+
|
| 151 |
+
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
| 152 |
+
logger.info(f"tokens are added: {token_ids}")
|
| 153 |
+
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
| 154 |
+
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
| 155 |
+
|
| 156 |
+
token_strings_XTI = []
|
| 157 |
+
XTI_layers = [
|
| 158 |
+
"IN01",
|
| 159 |
+
"IN02",
|
| 160 |
+
"IN04",
|
| 161 |
+
"IN05",
|
| 162 |
+
"IN07",
|
| 163 |
+
"IN08",
|
| 164 |
+
"MID",
|
| 165 |
+
"OUT03",
|
| 166 |
+
"OUT04",
|
| 167 |
+
"OUT05",
|
| 168 |
+
"OUT06",
|
| 169 |
+
"OUT07",
|
| 170 |
+
"OUT08",
|
| 171 |
+
"OUT09",
|
| 172 |
+
"OUT10",
|
| 173 |
+
"OUT11",
|
| 174 |
+
]
|
| 175 |
+
for layer_name in XTI_layers:
|
| 176 |
+
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
|
| 177 |
+
|
| 178 |
+
tokenizer.add_tokens(token_strings_XTI)
|
| 179 |
+
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
| 180 |
+
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
|
| 181 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
| 182 |
+
text_encoder.resize_token_embeddings(len(tokenizer))
|
| 183 |
+
|
| 184 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
| 185 |
+
token_embeds = text_encoder.get_input_embeddings().weight.data
|
| 186 |
+
if init_token_ids is not None:
|
| 187 |
+
for i, token_id in enumerate(token_ids_XTI):
|
| 188 |
+
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
|
| 189 |
+
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
| 190 |
+
|
| 191 |
+
# load weights
|
| 192 |
+
if args.weights is not None:
|
| 193 |
+
embeddings = load_weights(args.weights)
|
| 194 |
+
assert len(token_ids) == len(
|
| 195 |
+
embeddings
|
| 196 |
+
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
| 197 |
+
# logger.info(token_ids, embeddings.size())
|
| 198 |
+
for token_id, embedding in zip(token_ids_XTI, embeddings):
|
| 199 |
+
token_embeds[token_id] = embedding
|
| 200 |
+
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
| 201 |
+
logger.info(f"weighs loaded")
|
| 202 |
+
|
| 203 |
+
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
| 204 |
+
|
| 205 |
+
# データセットを準備する
|
| 206 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
|
| 207 |
+
if args.dataset_config is not None:
|
| 208 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 209 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 210 |
+
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
| 211 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 212 |
+
logger.info(
|
| 213 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 214 |
+
", ".join(ignored)
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
else:
|
| 218 |
+
use_dreambooth_method = args.in_json is None
|
| 219 |
+
if use_dreambooth_method:
|
| 220 |
+
logger.info("Use DreamBooth method.")
|
| 221 |
+
user_config = {
|
| 222 |
+
"datasets": [
|
| 223 |
+
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
| 224 |
+
]
|
| 225 |
+
}
|
| 226 |
+
else:
|
| 227 |
+
logger.info("Train with captions.")
|
| 228 |
+
user_config = {
|
| 229 |
+
"datasets": [
|
| 230 |
+
{
|
| 231 |
+
"subsets": [
|
| 232 |
+
{
|
| 233 |
+
"image_dir": args.train_data_dir,
|
| 234 |
+
"metadata_file": args.in_json,
|
| 235 |
+
}
|
| 236 |
+
]
|
| 237 |
+
}
|
| 238 |
+
]
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
| 242 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 243 |
+
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
| 244 |
+
current_epoch = Value("i", 0)
|
| 245 |
+
current_step = Value("i", 0)
|
| 246 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 247 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 248 |
+
|
| 249 |
+
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
| 250 |
+
if use_template:
|
| 251 |
+
logger.info(f"use template for training captions. is object: {args.use_object_template}")
|
| 252 |
+
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
| 253 |
+
replace_to = " ".join(token_strings)
|
| 254 |
+
captions = []
|
| 255 |
+
for tmpl in templates:
|
| 256 |
+
captions.append(tmpl.format(replace_to))
|
| 257 |
+
train_dataset_group.add_replacement("", captions)
|
| 258 |
+
|
| 259 |
+
if args.num_vectors_per_token > 1:
|
| 260 |
+
prompt_replacement = (args.token_string, replace_to)
|
| 261 |
+
else:
|
| 262 |
+
prompt_replacement = None
|
| 263 |
+
else:
|
| 264 |
+
if args.num_vectors_per_token > 1:
|
| 265 |
+
replace_to = " ".join(token_strings)
|
| 266 |
+
train_dataset_group.add_replacement(args.token_string, replace_to)
|
| 267 |
+
prompt_replacement = (args.token_string, replace_to)
|
| 268 |
+
else:
|
| 269 |
+
prompt_replacement = None
|
| 270 |
+
|
| 271 |
+
if args.debug_dataset:
|
| 272 |
+
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
| 273 |
+
return
|
| 274 |
+
if len(train_dataset_group) == 0:
|
| 275 |
+
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
| 276 |
+
return
|
| 277 |
+
|
| 278 |
+
if cache_latents:
|
| 279 |
+
assert (
|
| 280 |
+
train_dataset_group.is_latent_cacheable()
|
| 281 |
+
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
| 282 |
+
|
| 283 |
+
# モデルに xformers とか memory efficient attention を組み込む
|
| 284 |
+
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
|
| 285 |
+
original_unet.UNet2DConditionModel.forward = unet_forward_XTI
|
| 286 |
+
original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
| 287 |
+
original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
| 288 |
+
|
| 289 |
+
# 学習を準備する
|
| 290 |
+
if cache_latents:
|
| 291 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 292 |
+
vae.requires_grad_(False)
|
| 293 |
+
vae.eval()
|
| 294 |
+
with torch.no_grad():
|
| 295 |
+
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
| 296 |
+
vae.to("cpu")
|
| 297 |
+
clean_memory_on_device(accelerator.device)
|
| 298 |
+
|
| 299 |
+
accelerator.wait_for_everyone()
|
| 300 |
+
|
| 301 |
+
if args.gradient_checkpointing:
|
| 302 |
+
unet.enable_gradient_checkpointing()
|
| 303 |
+
text_encoder.gradient_checkpointing_enable()
|
| 304 |
+
|
| 305 |
+
# 学習に必要なクラスを準備する
|
| 306 |
+
logger.info("prepare optimizer, data loader etc.")
|
| 307 |
+
trainable_params = text_encoder.get_input_embeddings().parameters()
|
| 308 |
+
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
| 309 |
+
|
| 310 |
+
# dataloaderを準備する
|
| 311 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 312 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 313 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 314 |
+
train_dataset_group,
|
| 315 |
+
batch_size=1,
|
| 316 |
+
shuffle=True,
|
| 317 |
+
collate_fn=collator,
|
| 318 |
+
num_workers=n_workers,
|
| 319 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# 学習ステップ数を計算する
|
| 323 |
+
if args.max_train_epochs is not None:
|
| 324 |
+
args.max_train_steps = args.max_train_epochs * math.ceil(
|
| 325 |
+
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
| 326 |
+
)
|
| 327 |
+
logger.info(
|
| 328 |
+
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
# データセット側にも学習ステップを送信
|
| 332 |
+
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
| 333 |
+
|
| 334 |
+
# lr schedulerを用意する
|
| 335 |
+
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
| 336 |
+
|
| 337 |
+
# acceleratorがなんかよろしくやってくれるらしい
|
| 338 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 339 |
+
text_encoder, optimizer, train_dataloader, lr_scheduler
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
| 343 |
+
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
|
| 344 |
+
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
| 345 |
+
|
| 346 |
+
# Freeze all parameters except for the token embeddings in text encoder
|
| 347 |
+
text_encoder.requires_grad_(True)
|
| 348 |
+
text_encoder.text_model.encoder.requires_grad_(False)
|
| 349 |
+
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
| 350 |
+
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
| 351 |
+
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
| 352 |
+
|
| 353 |
+
unet.requires_grad_(False)
|
| 354 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 355 |
+
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
| 356 |
+
unet.train()
|
| 357 |
+
else:
|
| 358 |
+
unet.eval()
|
| 359 |
+
|
| 360 |
+
if not cache_latents:
|
| 361 |
+
vae.requires_grad_(False)
|
| 362 |
+
vae.eval()
|
| 363 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 364 |
+
|
| 365 |
+
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
| 366 |
+
if args.full_fp16:
|
| 367 |
+
train_util.patch_accelerator_for_fp16_training(accelerator)
|
| 368 |
+
text_encoder.to(weight_dtype)
|
| 369 |
+
|
| 370 |
+
# resumeする
|
| 371 |
+
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
| 372 |
+
|
| 373 |
+
# epoch数を計算する
|
| 374 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 375 |
+
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 376 |
+
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
| 377 |
+
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
| 378 |
+
|
| 379 |
+
# 学習する
|
| 380 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 381 |
+
logger.info("running training / 学習開始")
|
| 382 |
+
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
| 383 |
+
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
| 384 |
+
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
| 385 |
+
logger.info(f" num epochs / epoch数: {num_train_epochs}")
|
| 386 |
+
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
| 387 |
+
logger.info(
|
| 388 |
+
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
| 389 |
+
)
|
| 390 |
+
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
| 391 |
+
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
| 392 |
+
|
| 393 |
+
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
| 394 |
+
global_step = 0
|
| 395 |
+
|
| 396 |
+
noise_scheduler = DDPMScheduler(
|
| 397 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
| 398 |
+
)
|
| 399 |
+
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
| 400 |
+
if args.zero_terminal_snr:
|
| 401 |
+
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
| 402 |
+
|
| 403 |
+
if accelerator.is_main_process:
|
| 404 |
+
init_kwargs = {}
|
| 405 |
+
if args.wandb_run_name:
|
| 406 |
+
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
| 407 |
+
if args.log_tracker_config is not None:
|
| 408 |
+
init_kwargs = toml.load(args.log_tracker_config)
|
| 409 |
+
accelerator.init_trackers(
|
| 410 |
+
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
# function for saving/removing
|
| 414 |
+
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
|
| 415 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 416 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
| 417 |
+
|
| 418 |
+
logger.info("")
|
| 419 |
+
logger.info(f"saving checkpoint: {ckpt_file}")
|
| 420 |
+
save_weights(ckpt_file, embs, save_dtype)
|
| 421 |
+
if args.huggingface_repo_id is not None:
|
| 422 |
+
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
| 423 |
+
|
| 424 |
+
def remove_model(old_ckpt_name):
|
| 425 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
| 426 |
+
if os.path.exists(old_ckpt_file):
|
| 427 |
+
logger.info(f"removing old checkpoint: {old_ckpt_file}")
|
| 428 |
+
os.remove(old_ckpt_file)
|
| 429 |
+
|
| 430 |
+
# training loop
|
| 431 |
+
for epoch in range(num_train_epochs):
|
| 432 |
+
logger.info("")
|
| 433 |
+
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
|
| 434 |
+
current_epoch.value = epoch + 1
|
| 435 |
+
|
| 436 |
+
text_encoder.train()
|
| 437 |
+
|
| 438 |
+
loss_total = 0
|
| 439 |
+
|
| 440 |
+
for step, batch in enumerate(train_dataloader):
|
| 441 |
+
current_step.value = global_step
|
| 442 |
+
with accelerator.accumulate(text_encoder):
|
| 443 |
+
with torch.no_grad():
|
| 444 |
+
if "latents" in batch and batch["latents"] is not None:
|
| 445 |
+
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
| 446 |
+
else:
|
| 447 |
+
# latentに変換
|
| 448 |
+
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
| 449 |
+
latents = latents * 0.18215
|
| 450 |
+
b_size = latents.shape[0]
|
| 451 |
+
|
| 452 |
+
# Get the text embedding for conditioning
|
| 453 |
+
input_ids = batch["input_ids"].to(accelerator.device)
|
| 454 |
+
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
| 455 |
+
encoder_hidden_states = torch.stack(
|
| 456 |
+
[
|
| 457 |
+
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
|
| 458 |
+
for s in torch.split(input_ids, 1, dim=1)
|
| 459 |
+
]
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
| 463 |
+
# with noise offset and/or multires noise if specified
|
| 464 |
+
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
| 465 |
+
|
| 466 |
+
# Predict the noise residual
|
| 467 |
+
with accelerator.autocast():
|
| 468 |
+
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
|
| 469 |
+
|
| 470 |
+
if args.v_parameterization:
|
| 471 |
+
# v-parameterization training
|
| 472 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 473 |
+
else:
|
| 474 |
+
target = noise
|
| 475 |
+
|
| 476 |
+
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
|
| 477 |
+
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
| 478 |
+
loss = apply_masked_loss(loss, batch)
|
| 479 |
+
loss = loss.mean([1, 2, 3])
|
| 480 |
+
|
| 481 |
+
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
| 482 |
+
|
| 483 |
+
loss = loss * loss_weights
|
| 484 |
+
if args.min_snr_gamma:
|
| 485 |
+
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
| 486 |
+
if args.scale_v_pred_loss_like_noise_pred:
|
| 487 |
+
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
| 488 |
+
if args.debiased_estimation_loss:
|
| 489 |
+
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
|
| 490 |
+
|
| 491 |
+
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
| 492 |
+
|
| 493 |
+
accelerator.backward(loss)
|
| 494 |
+
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
| 495 |
+
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
| 496 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 497 |
+
|
| 498 |
+
optimizer.step()
|
| 499 |
+
lr_scheduler.step()
|
| 500 |
+
optimizer.zero_grad(set_to_none=True)
|
| 501 |
+
|
| 502 |
+
# Let's make sure we don't update any embedding weights besides the newly added token
|
| 503 |
+
with torch.no_grad():
|
| 504 |
+
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
|
| 505 |
+
index_no_updates
|
| 506 |
+
]
|
| 507 |
+
|
| 508 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 509 |
+
if accelerator.sync_gradients:
|
| 510 |
+
progress_bar.update(1)
|
| 511 |
+
global_step += 1
|
| 512 |
+
# TODO: fix sample_images
|
| 513 |
+
# train_util.sample_images(
|
| 514 |
+
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
| 515 |
+
# )
|
| 516 |
+
|
| 517 |
+
# 指定ステップごとにモデルを保存
|
| 518 |
+
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
| 519 |
+
accelerator.wait_for_everyone()
|
| 520 |
+
if accelerator.is_main_process:
|
| 521 |
+
updated_embs = (
|
| 522 |
+
accelerator.unwrap_model(text_encoder)
|
| 523 |
+
.get_input_embeddings()
|
| 524 |
+
.weight[token_ids_XTI]
|
| 525 |
+
.data.detach()
|
| 526 |
+
.clone()
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
|
| 530 |
+
save_model(ckpt_name, updated_embs, global_step, epoch)
|
| 531 |
+
|
| 532 |
+
if args.save_state:
|
| 533 |
+
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
|
| 534 |
+
|
| 535 |
+
remove_step_no = train_util.get_remove_step_no(args, global_step)
|
| 536 |
+
if remove_step_no is not None:
|
| 537 |
+
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
|
| 538 |
+
remove_model(remove_ckpt_name)
|
| 539 |
+
|
| 540 |
+
current_loss = loss.detach().item()
|
| 541 |
+
if args.logging_dir is not None:
|
| 542 |
+
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
| 543 |
+
if (
|
| 544 |
+
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
| 545 |
+
): # tracking d*lr value
|
| 546 |
+
logs["lr/d*lr"] = (
|
| 547 |
+
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
| 548 |
+
)
|
| 549 |
+
accelerator.log(logs, step=global_step)
|
| 550 |
+
|
| 551 |
+
loss_total += current_loss
|
| 552 |
+
avr_loss = loss_total / (step + 1)
|
| 553 |
+
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
| 554 |
+
progress_bar.set_postfix(**logs)
|
| 555 |
+
|
| 556 |
+
if global_step >= args.max_train_steps:
|
| 557 |
+
break
|
| 558 |
+
|
| 559 |
+
if args.logging_dir is not None:
|
| 560 |
+
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
| 561 |
+
accelerator.log(logs, step=epoch + 1)
|
| 562 |
+
|
| 563 |
+
accelerator.wait_for_everyone()
|
| 564 |
+
|
| 565 |
+
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
| 566 |
+
|
| 567 |
+
if args.save_every_n_epochs is not None:
|
| 568 |
+
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
|
| 569 |
+
if accelerator.is_main_process and saving:
|
| 570 |
+
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
|
| 571 |
+
save_model(ckpt_name, updated_embs, epoch + 1, global_step)
|
| 572 |
+
|
| 573 |
+
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
| 574 |
+
if remove_epoch_no is not None:
|
| 575 |
+
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
|
| 576 |
+
remove_model(remove_ckpt_name)
|
| 577 |
+
|
| 578 |
+
if args.save_state:
|
| 579 |
+
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
|
| 580 |
+
|
| 581 |
+
# TODO: fix sample_images
|
| 582 |
+
# train_util.sample_images(
|
| 583 |
+
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
| 584 |
+
# )
|
| 585 |
+
|
| 586 |
+
# end of epoch
|
| 587 |
+
|
| 588 |
+
is_main_process = accelerator.is_main_process
|
| 589 |
+
if is_main_process:
|
| 590 |
+
text_encoder = accelerator.unwrap_model(text_encoder)
|
| 591 |
+
|
| 592 |
+
accelerator.end_training()
|
| 593 |
+
|
| 594 |
+
if is_main_process and (args.save_state or args.save_state_on_train_end):
|
| 595 |
+
train_util.save_state_on_train_end(args, accelerator)
|
| 596 |
+
|
| 597 |
+
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
|
| 598 |
+
|
| 599 |
+
del accelerator # この後メモリを使うのでこれは消す
|
| 600 |
+
|
| 601 |
+
if is_main_process:
|
| 602 |
+
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
| 603 |
+
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
| 604 |
+
|
| 605 |
+
logger.info("model saved.")
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def save_weights(file, updated_embs, save_dtype):
|
| 609 |
+
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
|
| 610 |
+
updated_embs = updated_embs.chunk(16)
|
| 611 |
+
XTI_layers = [
|
| 612 |
+
"IN01",
|
| 613 |
+
"IN02",
|
| 614 |
+
"IN04",
|
| 615 |
+
"IN05",
|
| 616 |
+
"IN07",
|
| 617 |
+
"IN08",
|
| 618 |
+
"MID",
|
| 619 |
+
"OUT03",
|
| 620 |
+
"OUT04",
|
| 621 |
+
"OUT05",
|
| 622 |
+
"OUT06",
|
| 623 |
+
"OUT07",
|
| 624 |
+
"OUT08",
|
| 625 |
+
"OUT09",
|
| 626 |
+
"OUT10",
|
| 627 |
+
"OUT11",
|
| 628 |
+
]
|
| 629 |
+
state_dict = {}
|
| 630 |
+
for i, layer_name in enumerate(XTI_layers):
|
| 631 |
+
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
|
| 632 |
+
|
| 633 |
+
# if save_dtype is not None:
|
| 634 |
+
# for key in list(state_dict.keys()):
|
| 635 |
+
# v = state_dict[key]
|
| 636 |
+
# v = v.detach().clone().to("cpu").to(save_dtype)
|
| 637 |
+
# state_dict[key] = v
|
| 638 |
+
|
| 639 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 640 |
+
from safetensors.torch import save_file
|
| 641 |
+
|
| 642 |
+
save_file(state_dict, file)
|
| 643 |
+
else:
|
| 644 |
+
torch.save(state_dict, file) # can be loaded in Web UI
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def load_weights(file):
|
| 648 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 649 |
+
from safetensors.torch import load_file
|
| 650 |
+
|
| 651 |
+
data = load_file(file)
|
| 652 |
+
else:
|
| 653 |
+
raise ValueError(f"NOT XTI: {file}")
|
| 654 |
+
|
| 655 |
+
if len(data.values()) != 16:
|
| 656 |
+
raise ValueError(f"NOT XTI: {file}")
|
| 657 |
+
|
| 658 |
+
emb = torch.concat([x for x in data.values()])
|
| 659 |
+
|
| 660 |
+
return emb
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 664 |
+
parser = argparse.ArgumentParser()
|
| 665 |
+
|
| 666 |
+
add_logging_arguments(parser)
|
| 667 |
+
train_util.add_sd_models_arguments(parser)
|
| 668 |
+
train_util.add_dataset_arguments(parser, True, True, False)
|
| 669 |
+
train_util.add_training_arguments(parser, True)
|
| 670 |
+
train_util.add_masked_loss_arguments(parser)
|
| 671 |
+
deepspeed_utils.add_deepspeed_arguments(parser)
|
| 672 |
+
train_util.add_optimizer_arguments(parser)
|
| 673 |
+
config_util.add_config_arguments(parser)
|
| 674 |
+
custom_train_functions.add_custom_train_arguments(parser, False)
|
| 675 |
+
|
| 676 |
+
parser.add_argument(
|
| 677 |
+
"--save_model_as",
|
| 678 |
+
type=str,
|
| 679 |
+
default="pt",
|
| 680 |
+
choices=[None, "ckpt", "pt", "safetensors"],
|
| 681 |
+
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
parser.add_argument(
|
| 685 |
+
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
|
| 686 |
+
)
|
| 687 |
+
parser.add_argument(
|
| 688 |
+
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
|
| 689 |
+
)
|
| 690 |
+
parser.add_argument(
|
| 691 |
+
"--token_string",
|
| 692 |
+
type=str,
|
| 693 |
+
default=None,
|
| 694 |
+
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
| 695 |
+
)
|
| 696 |
+
parser.add_argument(
|
| 697 |
+
"--init_word", type=str, default=None, help="words to initialize vector / ベ���トルを初期化に使用する単語、複数可"
|
| 698 |
+
)
|
| 699 |
+
parser.add_argument(
|
| 700 |
+
"--use_object_template",
|
| 701 |
+
action="store_true",
|
| 702 |
+
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
|
| 703 |
+
)
|
| 704 |
+
parser.add_argument(
|
| 705 |
+
"--use_style_template",
|
| 706 |
+
action="store_true",
|
| 707 |
+
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
return parser
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
if __name__ == "__main__":
|
| 714 |
+
parser = setup_parser()
|
| 715 |
+
|
| 716 |
+
args = parser.parse_args()
|
| 717 |
+
train_util.verify_command_line_training_args(args)
|
| 718 |
+
args = train_util.read_config_from_file(args, parser)
|
| 719 |
+
|
| 720 |
+
train(args)
|