Allex21 committed on
Commit eac965b · verified · 1 Parent(s): 7231b17

Upload 24 files

LICENSE.md ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2022] [kohya-ss]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MANUAL_DE_USO.md ADDED
@@ -0,0 +1,287 @@
1
+ # 📖 Manual de Uso - LoRA Trainer Funcional
2
+
3
+ ## 🎯 Visão Geral
4
+
5
+ Este LoRA Trainer é uma ferramenta **100% funcional** baseada no kohya-ss sd-scripts que permite treinar modelos LoRA reais para Stable Diffusion. A ferramenta foi desenvolvida especificamente para funcionar no Hugging Face Spaces e oferece todas as funcionalidades necessárias para um treinamento profissional.
6
+
7
+ ## 🚀 Início Rápido
8
+
9
+ ### Passo 1: Instalação das Dependências
10
+ 1. Acesse a aba "🔧 Instalação"
11
+ 2. Clique em "📦 Instalar Dependências"
12
+ 3. Aguarde a instalação completa (pode levar alguns minutos)
13
+
14
+ ### Passo 2: Configuração do Projeto
15
+ 1. Vá para a aba "📁 Configuração do Projeto"
16
+ 2. Digite um nome único para seu projeto (ex: "meu_lora_anime")
17
+ 3. Escolha um modelo base ou insira uma URL personalizada
18
+ 4. Clique em "📥 Baixar Modelo"
19
+
20
+ ### Passo 3: Preparação do Dataset
21
+ 1. Organize suas imagens em uma pasta local
22
+ 2. Para cada imagem, crie um arquivo .txt com o mesmo nome
23
+ 3. Compacte tudo em um arquivo ZIP
24
+ 4. Faça upload na seção "📊 Upload do Dataset"
25
+ 5. Clique em "📊 Processar Dataset"
26
+
27
+ ### Passo 4: Configuração dos Parâmetros
28
+ 1. Acesse a aba "⚙️ Parâmetros de Treinamento"
29
+ 2. Ajuste os parâmetros conforme sua necessidade
30
+ 3. Use as configurações recomendadas como ponto de partida
31
+
32
+ ### Passo 5: Treinamento
33
+ 1. Vá para a aba "🚀 Treinamento"
34
+ 2. Clique em "📝 Criar Configuração de Treinamento"
35
+ 3. Clique em "🎯 Iniciar Treinamento"
36
+ 4. Acompanhe o progresso em tempo real
37
+
38
+ ### Passo 6: Download dos Resultados
39
+ 1. Acesse a aba "📥 Download dos Resultados"
40
+ 2. Clique em "🔄 Atualizar Lista de Arquivos"
41
+ 3. Selecione e baixe seu LoRA treinado
42
+
43
+ ## 📋 Requisitos do Sistema
44
+
45
+ ### Mínimos
46
+ - **GPU**: NVIDIA com 6GB VRAM
47
+ - **RAM**: 8GB
48
+ - **Espaço**: 5GB livres
49
+
50
+ ### Recomendados
51
+ - **GPU**: NVIDIA com 12GB+ VRAM
52
+ - **RAM**: 16GB+
53
+ - **Espaço**: 20GB+ livres
54
+
55
+ ## 🎨 Preparação do Dataset
56
+
57
+ ### Estrutura Recomendada
58
+ ```
59
+ meu_dataset/
60
+ ├── imagem001.jpg
61
+ ├── imagem001.txt
62
+ ├── imagem002.png
63
+ ├── imagem002.txt
64
+ ├── imagem003.webp
65
+ ├── imagem003.txt
66
+ └── ...
67
+ ```
68
+
69
+ ### Formatos Suportados
70
+ - **Imagens**: JPG, PNG, WEBP, BMP, TIFF
71
+ - **Captions**: TXT (UTF-8)
72
+
73
+ ### Exemplo de Caption
74
+ ```
75
+ 1girl, long hair, blue eyes, school uniform, smile, outdoors, cherry blossoms, anime style, high quality
76
+ ```
77
+
78
+ ### Dicas para Captions
79
+ - Use vírgulas para separar tags
80
+ - Coloque tags importantes no início
81
+ - Seja específico mas conciso
82
+ - Use tags consistentes em todo o dataset
83
+
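+ Como referência, o esboço em Python abaixo (apenas ilustrativo, usando só a biblioteca padrão; os nomes de pasta e a caption padrão são exemplos) mostra uma forma de gerar os arquivos .txt de caption que faltam e empacotar tudo em um ZIP para o Passo 3:
+
+ ```python
+ from pathlib import Path
+ import zipfile
+
+ pasta_imagens = Path("meu_dataset")                   # pasta com as imagens já organizadas
+ caption_padrao = "1girl, anime style, high quality"   # exemplo; ajuste por imagem
+ extensoes = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff"}
+
+ # cria um .txt com o mesmo nome para cada imagem que ainda não tem caption
+ for imagem in sorted(pasta_imagens.iterdir()):
+     if imagem.suffix.lower() in extensoes:
+         caption = imagem.with_suffix(".txt")
+         if not caption.exists():
+             caption.write_text(caption_padrao, encoding="utf-8")
+
+ # compacta imagens e captions em um único ZIP para o upload
+ with zipfile.ZipFile("meu_dataset.zip", "w", zipfile.ZIP_DEFLATED) as zf:
+     for arquivo in sorted(pasta_imagens.iterdir()):
+         if arquivo.suffix.lower() in extensoes or arquivo.suffix.lower() == ".txt":
+             zf.write(arquivo, arcname=arquivo.name)
+ ```
+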
84
+ ## ⚙️ Configuração de Parâmetros
85
+
86
+ ### Parâmetros Básicos
87
+
88
+ #### Resolução
89
+ - **512px**: Padrão, mais rápido, menor uso de memória
90
+ - **768px**: Melhor qualidade, moderado
91
+ - **1024px**: Máxima qualidade, mais lento
92
+
93
+ #### Batch Size
94
+ - **1**: Menor uso de memória, mais lento
95
+ - **2-4**: Equilibrado (recomendado)
96
+ - **8+**: Apenas para GPUs potentes
97
+
98
+ #### Épocas
99
+ - **5-10**: Para datasets grandes (50+ imagens)
100
+ - **10-20**: Para datasets médios (20-50 imagens)
101
+ - **20-30**: Para datasets pequenos (10-20 imagens)
102
+
103
+ ### Parâmetros Avançados
104
+
105
+ #### Learning Rate
106
+ - **1e-3**: Muito alto, pode causar instabilidade
107
+ - **5e-4**: Padrão, bom para a maioria dos casos
108
+ - **1e-4**: Conservador, para datasets grandes
109
+ - **5e-5**: Muito baixo, treinamento lento
110
+
111
+ #### Network Dimension
112
+ - **8-16**: LoRAs pequenos, menos detalhes
113
+ - **32**: Padrão, bom equilíbrio
114
+ - **64-128**: Mais detalhes, arquivos maiores
115
+
116
+ #### Network Alpha
117
+ - Geralmente metade do Network Dimension
118
+ - Controla a força do LoRA
119
+ - Valores menores = efeito mais sutil
120
+
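+ Para dar uma ideia concreta, o cálculo abaixo (esboço ilustrativo; a escala `alpha / dim` é a convenção usada pelo kohya sd-scripts) mostra a força efetiva resultante:
+
+ ```python
+ # A força efetiva aplicada às matrizes do LoRA é proporcional a network_alpha / network_dim
+ network_dim = 32
+ network_alpha = 16
+
+ escala_efetiva = network_alpha / network_dim
+ print(f"escala efetiva: {escala_efetiva}")  # 0.5 -> ponto de partida equilibrado
+ ```
+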
121
+ ### Tipos de LoRA
122
+
123
+ #### LoRA Clássico
124
+ - Menor tamanho de arquivo
125
+ - Bom para uso geral
126
+ - Mais rápido para treinar
127
+
128
+ #### LoCon
129
+ - Melhor para estilos artísticos
130
+ - Mais camadas de aprendizado
131
+ - Arquivos maiores
132
+
133
+ ## 🎯 Configurações por Tipo de Projeto
134
+
135
+ ### Para Personagens/Pessoas
136
+ ```
137
+ Imagens: 15-30 variadas
138
+ Network Dim: 32
139
+ Network Alpha: 16
140
+ Learning Rate: 1e-4
141
+ Épocas: 10-15
142
+ Batch Size: 2
143
+ ```
144
+
145
+ ### Para Estilos Artísticos
146
+ ```
147
+ Imagens: 30-50 do estilo
148
+ Tipo: LoCon
149
+ Network Dim: 64
150
+ Network Alpha: 32
151
+ Learning Rate: 5e-5
152
+ Épocas: 15-25
153
+ Batch Size: 1-2
154
+ ```
155
+
156
+ ### Para Objetos/Conceitos
157
+ ```
158
+ Imagens: 10-25
159
+ Network Dim: 16
160
+ Network Alpha: 8
161
+ Learning Rate: 5e-4
162
+ Épocas: 8-12
163
+ Batch Size: 2-4
164
+ ```
165
+
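+ Como referência, o esboço abaixo mostra como os valores da tabela "Para Personagens/Pessoas" poderiam ser gravados em um `dataset_config.toml` no formato do kohya sd-scripts (nomes de chaves, caminhos e repetições são ilustrativos; confira a documentação de dataset config do sd-scripts antes de usar):
+
+ ```python
+ from pathlib import Path
+
+ # valores inspirados na tabela "Para Personagens/Pessoas" acima
+ config = """\
+ [general]
+ caption_extension = ".txt"
+ shuffle_caption = true
+
+ [[datasets]]
+ resolution = 512
+ batch_size = 2
+
+ [[datasets.subsets]]
+ image_dir = "/tmp/lora_training/projects/meu_projeto/dataset"
+ num_repeats = 10
+ """
+
+ Path("dataset_config.toml").write_text(config, encoding="utf-8")
+ print(config)
+ ```
+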
166
+ ## 🔧 Solução de Problemas
167
+
168
+ ### Erro de Memória (CUDA OOM)
169
+ **Sintomas**: "CUDA out of memory"
170
+ **Soluções**:
171
+ - Reduza o batch size para 1
172
+ - Diminua a resolução para 512px
173
+ - Use mixed precision fp16
174
+
175
+ ### Treinamento Muito Lento
176
+ **Sintomas**: Progresso muito lento
177
+ **Soluções**:
178
+ - Aumente o batch size (se possível)
179
+ - Use resolução menor
180
+ - Verifique se xFormers está ativo
181
+
182
+ ### Resultados Ruins/Overfitting
183
+ **Sintomas**: LoRA não funciona ou muito forte
184
+ **Soluções**:
185
+ - Reduza o learning rate
186
+ - Diminua o número de épocas
187
+ - Use mais imagens variadas
188
+ - Ajuste network alpha
189
+
190
+ ### Erro de Configuração
191
+ **Sintomas**: Falha ao criar configuração
192
+ **Soluções**:
193
+ - Verifique se o modelo foi baixado
194
+ - Confirme que o dataset foi processado
195
+ - Reinicie a aplicação
196
+
197
+ ## 📊 Monitoramento do Treinamento
198
+
199
+ ### Métricas Importantes
200
+ - **Loss**: Deve diminuir gradualmente
201
+ - **Learning Rate**: Varia conforme scheduler
202
+ - **Tempo por Época**: Depende do hardware
203
+
204
+ ### Sinais de Bom Treinamento
205
+ - Loss diminui consistentemente
206
+ - Sem erros de memória
207
+ - Progresso estável
208
+
209
+ ### Sinais de Problemas
210
+ - Loss oscila muito
211
+ - Erros frequentes
212
+ - Progresso muito lento
213
+
214
+ ## 💾 Gerenciamento de Arquivos
215
+
216
+ ### Estrutura de Saída
217
+ ```
218
+ /tmp/lora_training/projects/meu_projeto/
219
+ ├── dataset/ # Imagens processadas
220
+ ├── output/ # LoRAs gerados
221
+ ├── logs/ # Logs do treinamento
222
+ ├── dataset_config.toml
223
+ └── training_config.toml
224
+ ```
225
+
226
+ ### Arquivos Gerados
227
+ - **projeto_epoch_0001.safetensors**: LoRA da época 1
228
+ - **projeto_epoch_0010.safetensors**: LoRA da época 10
229
+ - **logs/**: Logs detalhados do TensorBoard
230
+
231
+ ## 🎨 Uso dos LoRAs Treinados
232
+
233
+ ### No Automatic1111
234
+ 1. Copie o arquivo .safetensors para `models/Lora/`
235
+ 2. Use na prompt: `<lora:nome_do_arquivo:0.8>`
236
+ 3. Ajuste o peso (0.1 a 1.5)
237
+
238
+ ### No ComfyUI
239
+ 1. Coloque o arquivo em `models/loras/`
240
+ 2. Use o nó "Load LoRA"
241
+ 3. Conecte ao modelo
242
+
243
+ ### Pesos Recomendados
244
+ - **0.3-0.6**: Efeito sutil
245
+ - **0.7-1.0**: Efeito padrão
246
+ - **1.1-1.5**: Efeito forte
247
+
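+ Também é possível testar o LoRA treinado em código com a biblioteca `diffusers` (esboço ilustrativo; o modelo base, o caminho e o nome do arquivo são apenas exemplos e a API pode variar conforme a versão instalada):
+
+ ```python
+ import torch
+ from diffusers import StableDiffusionPipeline
+
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ).to("cuda")
+
+ # carrega o LoRA treinado (.safetensors gerado na pasta output/)
+ pipe.load_lora_weights(
+     "/tmp/lora_training/projects/meu_projeto/output",
+     weight_name="meu_projeto_epoch_0010.safetensors",
+ )
+
+ imagem = pipe(
+     "1girl, long hair, blue eyes, school uniform, smile",
+     cross_attention_kwargs={"scale": 0.8},  # peso do LoRA (0.1 a 1.5)
+     num_inference_steps=28,
+ ).images[0]
+ imagem.save("teste_lora.png")
+ ```
+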
248
+ ## 🔄 Melhores Práticas
249
+
250
+ ### Antes do Treinamento
251
+ 1. **Qualidade sobre Quantidade**: 20 imagens boas > 100 ruins
252
+ 2. **Variedade**: Use ângulos, poses e cenários diferentes
253
+ 3. **Consistência**: Mantenha estilo consistente nas captions
254
+ 4. **Backup**: Salve configurações que funcionaram
255
+
256
+ ### Durante o Treinamento
257
+ 1. **Monitore**: Acompanhe o progresso regularmente
258
+ 2. **Paciência**: Não interrompa sem necessidade
259
+ 3. **Recursos**: Monitore uso de GPU/RAM
260
+
261
+ ### Após o Treinamento
262
+ 1. **Teste**: Experimente diferentes pesos
263
+ 2. **Compare**: Teste épocas diferentes
264
+ 3. **Documente**: Anote configurações que funcionaram
265
+ 4. **Compartilhe**: Considere compartilhar bons resultados
266
+
267
+ ## 🆘 Suporte e Recursos
268
+
269
+ ### Documentação Adicional
270
+ - [Guia Oficial kohya-ss](https://github.com/kohya-ss/sd-scripts)
271
+ - [Documentação Diffusers](https://huggingface.co/docs/diffusers)
272
+ - [Comunidade Stable Diffusion](https://discord.gg/stable-diffusion)
273
+
274
+ ### Logs e Debug
275
+ - Verifique os logs em `/tmp/lora_training/projects/seu_projeto/logs/`
276
+ - Use TensorBoard para visualizar métricas
277
+ - Salve configurações que funcionaram bem
278
+
279
+ ### Limitações Conhecidas
280
+ - Requer GPU NVIDIA com CUDA
281
+ - Modelos grandes podem precisar de mais memória
282
+ - Treinamento pode ser lento em hardware limitado
283
+
284
+ ---
285
+
286
+ **Nota**: Esta ferramenta é para fins educacionais e de pesquisa. Use responsavelmente e respeite direitos autorais das imagens utilizadas.
287
+
README-ja.md ADDED
@@ -0,0 +1,177 @@
1
+ ## リポジトリについて
2
+ Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
3
+
4
+ [README in English](./README.md) ←更新情報はこちらにあります
5
+
6
+ 開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
7
+
8
+ FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
9
+
10
+ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
11
+
12
+ 以下のスクリプトがあります。
13
+
14
+ * DreamBooth、U-NetおよびText Encoderの学習をサポート
15
+ * fine-tuning、同上
16
+ * LoRAの学習をサポート
17
+ * 画像生成
18
+ * モデル変換(Stable Diffusion ckpt/safetensorsとDiffusersの相互変換)
19
+
20
+ ### スポンサー
21
+
22
+ このプロジェクトを支援してくださる企業・団体の皆様に深く感謝いたします。
23
+
24
+ <a href="https://aihub.co.jp/">
25
+ <img src="./images/logo_aihub.png" alt="AiHUB株式会社" title="AiHUB株式会社" height="100px">
26
+ </a>
27
+
28
+ ### スポンサー募集のお知らせ
29
+
30
+ このプロジェクトがお役に立ったなら、ご支援いただけると嬉しく思います。 [GitHub Sponsors](https://github.com/sponsors/kohya-ss/)で受け付けています。
31
+
32
+ ## 使用法について
33
+
34
+ * [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
35
+ * [データセット設定](./docs/config_README-ja.md)
36
+ * [SDXL学習](./docs/train_SDXL-en.md) (英語版)
37
+ * [DreamBoothの学習について](./docs/train_db_README-ja.md)
38
+ * [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
39
+ * [LoRAの学習について](./docs/train_network_README-ja.md)
40
+ * [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
41
+ * [画像生成スクリプト](./docs/gen_img_README-ja.md)
42
+ * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
43
+
44
+ ## Windowsでの動作に必要なプログラム
45
+
46
+ Python 3.10.6およびGitが必要です。
47
+
48
+ - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
49
+ - git: https://git-scm.com/download/win
50
+
51
+ Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
52
+
53
+ PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
54
+ (venvに限らずスクリプトの実行が可能になりますので注意してください。)
55
+
56
+ - PowerShellを管理者として開きます。
57
+ - 「Set-ExecutionPolicy Unrestricted」と入力し、Yと答えます。
58
+ - 管理者のPowerShellを閉じます。
59
+
60
+ ## Windows環境でのインストール
61
+
62
+ スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。
63
+
64
+ (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。)
65
+
66
+ PowerShellを使う場合、通常の(管理者ではない)PowerShellを開き以下を順に実行します。
67
+
68
+ ```powershell
69
+ git clone https://github.com/kohya-ss/sd-scripts.git
70
+ cd sd-scripts
71
+
72
+ python -m venv venv
73
+ .\venv\Scripts\activate
74
+
75
+ pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
76
+ pip install --upgrade -r requirements.txt
77
+ pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
78
+
79
+ accelerate config
80
+ ```
81
+
82
+ コマンドプロンプトでも同一です。
83
+
84
+ 注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
85
+
86
+ この例では PyTorch および xformers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えてください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
87
+
88
+ PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
89
+
90
+ accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
91
+
92
+ ```txt
93
+ - This machine
94
+ - No distributed training
95
+ - NO
96
+ - NO
97
+ - NO
98
+ - all
99
+ - fp16
100
+ ```
101
+
102
+ ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
103
+ ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
104
+
105
+ ## アップグレード
106
+
107
+ 新しいリリースがあった場合、以下のコマンドで更新できます。
108
+
109
+ ```powershell
110
+ cd sd-scripts
111
+ git pull
112
+ .\venv\Scripts\activate
113
+ pip install --use-pep517 --upgrade -r requirements.txt
114
+ ```
115
+
116
+ コマンドが成功すれば新しいバージョンが使用できます。
117
+
118
+ ## 謝意
119
+
120
+ LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
121
+
122
+ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
123
+
124
+ ## ライセンス
125
+
126
+ スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。
127
+
128
+ [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
129
+
130
+ [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
131
+
132
+ [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
133
+
134
+ ## その他の情報
135
+
136
+ ### LoRAの名称について
137
+
138
+ `train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
139
+
140
+ 1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
141
+
142
+ Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
143
+
144
+ 2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
145
+
146
+ 1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
147
+
148
+ デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
149
+
150
+ <!--
151
+ LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
152
+
153
+ LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
154
+ -->
155
+
156
+ ### 学習中のサンプル画像生成
157
+
158
+ プロンプトファイルは例えば以下のようになります。
159
+
160
+ ```
161
+ # prompt 1
162
+ masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
163
+
164
+ # prompt 2
165
+ masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
166
+ ```
167
+
168
+ `#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能です。
169
+
170
+ * `--n` Negative prompt up to the next option.
171
+ * `--w` Specifies the width of the generated image.
172
+ * `--h` Specifies the height of the generated image.
173
+ * `--d` Specifies the seed of the generated image.
174
+ * `--l` Specifies the CFG scale of the generated image.
175
+ * `--s` Specifies the number of steps in the generation.
176
+
177
+ `( )` や `[ ]` などの重みづけも動作します。
README.md CHANGED
@@ -1,12 +1,609 @@
1
- ---
2
- title: TrainL
3
- emoji: 📈
4
- colorFrom: red
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.46.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ This repository contains training, generation and utility scripts for Stable Diffusion.
2
+
3
+ [__Change History__](#change-history) is moved to the bottom of the page.
4
+ 更新履歴は[ページ末尾](#change-history)に移しました。
5
+
6
+ Latest update: 2025-03-21 (Version 0.9.1)
7
+
8
+ [日本語版READMEはこちら](./README-ja.md)
9
+
10
+ The development version is in the `dev` branch. Please check the dev branch for the latest changes.
11
+
12
+ FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
13
+
14
+
15
+ For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
16
+
17
+ This repository contains the scripts for:
18
+
19
+ * DreamBooth training, including U-Net and Text Encoder
20
+ * Fine-tuning (native training), including U-Net and Text Encoder
21
+ * LoRA training
22
+ * Textual Inversion training
23
+ * Image generation
24
+ * Model conversion (supports 1.x and 2.x, Stable Diffusion ckpt/safetensors and Diffusers)
25
+
26
+ ### Sponsors
27
+
28
+ We are grateful to the following companies for their generous sponsorship:
29
+
30
+ <a href="https://aihub.co.jp/top-en">
31
+ <img src="./images/logo_aihub.png" alt="AiHUB Inc." title="AiHUB Inc." height="100px">
32
+ </a>
33
+
34
+ ### Support the Project
35
+
36
+ If you find this project helpful, please consider supporting its development via [GitHub Sponsors](https://github.com/sponsors/kohya-ss/). Your support is greatly appreciated!
37
+
38
+
39
+ ## About requirements.txt
40
+
41
+ This file does not include the requirements for PyTorch, because the appropriate PyTorch version depends on your environment. Please install PyTorch first, according to your environment (see the installation instructions below).
42
+
43
+ The scripts are tested with PyTorch 2.1.2; PyTorch 2.2 or later should also work. Please install the appropriate versions of PyTorch and xformers.
44
+
45
+ ## Links to usage documentation
46
+
47
+ Most of the documents are written in Japanese.
48
+
49
+ [English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!
50
+
51
+ * [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
52
+ * [Chinese version](./docs/train_README-zh.md)
53
+ * [SDXL training](./docs/train_SDXL-en.md) (English version)
54
+ * [Dataset config](./docs/config_README-ja.md)
55
+ * [English version](./docs/config_README-en.md)
56
+ * [DreamBooth training guide](./docs/train_db_README-ja.md)
57
+ * [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
58
+ * [Training LoRA](./docs/train_network_README-ja.md)
59
+ * [Training Textual Inversion](./docs/train_ti_README-ja.md)
60
+ * [Image generation](./docs/gen_img_README-ja.md)
61
+ * note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
62
+
63
+ ## Windows Required Dependencies
64
+
65
+ Python 3.10.6 and Git:
66
+
67
+ - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
68
+ - git: https://git-scm.com/download/win
69
+
70
+ Python 3.10.x, 3.11.x, and 3.12.x should also work, but they have not been tested.
71
+
72
+ Give unrestricted script access to powershell so venv can work:
73
+
74
+ - Open an administrator powershell window
75
+ - Type `Set-ExecutionPolicy Unrestricted` and answer A
76
+ - Close admin powershell window
77
+
78
+ ## Windows Installation
79
+
80
+ Open a regular Powershell terminal and type the following inside:
81
+
82
+ ```powershell
83
+ git clone https://github.com/kohya-ss/sd-scripts.git
84
+ cd sd-scripts
85
+
86
+ python -m venv venv
87
+ .\venv\Scripts\activate
88
+
89
+ pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
90
+ pip install --upgrade -r requirements.txt
91
+ pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
92
+
93
+ accelerate config
94
+ ```
95
+
96
+ If `python -m venv` shows only `python`, change `python` to `py`.
97
+
98
+ Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use another version, please install it manually.
99
+
100
+ This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate versions of PyTorch and xformers. For example, for CUDA 12.1, run `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
101
+
102
+ If you use PyTorch 2.2 or later, please change `torch==2.1.2`, `torchvision==0.16.2`, and `xformers==0.0.23.post1` to the appropriate versions.
103
+
104
+ <!--
105
+ cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
106
+ cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
107
+ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
108
+ -->
109
+ Answers to accelerate config:
110
+
111
+ ```txt
112
+ - This machine
113
+ - No distributed training
114
+ - NO
115
+ - NO
116
+ - NO
117
+ - all
118
+ - fp16
119
+ ```
120
+
121
+ If you'd like to use bf16, please answer `bf16` to the last question.
122
+
123
+ Note: Some users report that ``ValueError: fp16 mixed precision requires a GPU`` occurs during training. In this case, answer `0` to the 6th question:
124
+ ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
125
+
126
+ (Single GPU with id `0` will be used.)
127
+
128
+ ## Upgrade
129
+
130
+ When a new release comes out, you can upgrade your repo with the following commands:
131
+
132
+ ```powershell
133
+ cd sd-scripts
134
+ git pull
135
+ .\venv\Scripts\activate
136
+ pip install --use-pep517 --upgrade -r requirements.txt
137
+ ```
138
+
139
+ Once the commands have completed successfully you should be ready to use the new version.
140
+
141
+ ### Upgrade PyTorch
142
+
143
+ If you want to upgrade PyTorch, you can upgrade it with the `pip install` command shown in the [Windows Installation](#windows-installation) section. `xformers` also needs to be upgraded when PyTorch is upgraded.
144
+
145
+ ## Credits
146
+
147
+ The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for the great work!
148
+
149
+ The LoRA expansion to Conv2d 3x3 was initially released by cloneofsimo and its effectiveness was demonstrated at [LoCon](https://github.com/KohakuBlueleaf/LoCon) by KohakuBlueleaf. Thank you so much KohakuBlueleaf!
150
+
151
+ ## License
152
+
153
+ The majority of the scripts are licensed under ASL 2.0 (including code from Diffusers, cloneofsimo's repo, and LoCon); however, portions of the project are available under separate license terms:
154
+
155
+ [Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
156
+
157
+ [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
158
+
159
+ [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
160
+
161
+
162
+ ## Change History
163
+
164
+ ### Mar 21, 2025 / 2025-03-21 Version 0.9.1
165
+
166
+ - Fixed a bug where some of the LoRA modules for the CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
167
+ - The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
168
+
169
+ ### Jan 17, 2025 / 2025-01-17 Version 0.9.0
170
+
171
+ - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
172
+ - bitsandbytes, transformers, accelerate and huggingface_hub are updated.
173
+ - If you encounter any issues, please report them.
174
+
175
+ - The dev branch is merged into main. The documentation is delayed, and I apologize for that. I will gradually improve it.
176
+ - The state just before the merge is released as Version 0.8.8, so please use it if you encounter any issues.
177
+ - The following changes are included.
178
+
179
+ #### Changes
180
+
181
+ - Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
182
+ - Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
183
+
184
+ - There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
185
+
186
+ - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
187
+ - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).
188
+
189
+ - Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
190
+
191
+ - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
192
+
193
+ - Improvements in OFT (Orthogonal Finetuning) Implementation
194
+ 1. Optimization of Calculation Order:
195
+ - Changed the calculation order in the forward method from (Wx)R to W(xR).
196
+ - This has improved computational efficiency and processing speed.
197
+ 2. Correction of Bias Application:
198
+ - In the previous implementation, R was incorrectly applied to the bias.
199
+ - The new implementation now correctly handles bias by using F.conv2d and F.linear.
200
+ 3. Efficiency Enhancement in Matrix Operations:
201
+ - Introduced einsum in both the forward and merge_to methods.
202
+ - This has optimized matrix operations, resulting in further speed improvements.
203
+ 4. Proper Handling of Data Types:
204
+ - Improved to use torch.float32 during calculations and convert results back to the original data type.
205
+ - This maintains precision while ensuring compatibility with the original model.
206
+ 5. Unified Processing for Conv2d and Linear Layers:
207
+ - Implemented a consistent method for applying OFT to both layer types.
208
+ - These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
209
+
210
+ - Additional Information
211
+ * Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
212
+
213
+ * Performance Improvement: Training speed has been improved by approximately 30%.
214
+
215
+ * Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
216
+
217
+ - The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
218
+ - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
219
+ - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
220
+
221
+ - When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
222
+
223
+ - Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
224
+
225
+ - `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
226
+
227
+ - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
228
+ - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
229
+ - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
230
+ - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
231
+ - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
232
+ - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
233
+ - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts.
234
+
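+ For illustration, here is a minimal, generic PyTorch sketch of the per-parameter "step inside backward" pattern enabled by `Tensor.register_post_accumulate_grad_hook` (a toy AdamW example loosely following the PyTorch docs, not the actual sd-scripts implementation, which currently supports only AdaFactor):
+
+ ```python
+ import torch
+
+ model = torch.nn.Linear(512, 512)  # stand-in for the real model
+
+ # one optimizer per parameter, so each step can run as soon as that grad is ready
+ optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-4) for p in model.parameters()}
+
+ def optimizer_hook(param: torch.Tensor) -> None:
+     # called right after param.grad has been accumulated during backward
+     optimizer_dict[param].step()
+     optimizer_dict[param].zero_grad()
+     # the gradient is consumed immediately instead of being kept for the whole model
+
+ for p in model.parameters():
+     p.register_post_accumulate_grad_hook(optimizer_hook)
+
+ # training step: note there is no explicit optimizer.step() / zero_grad() afterwards
+ loss = model(torch.randn(4, 512)).sum()
+ loss.backward()
+ ```
+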
235
+ - Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
236
+ - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
237
+ - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
238
+ - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
239
+ - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
240
+ - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
241
+
242
+ - LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
243
+ - LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
244
+ - Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
245
+ - `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
246
+ - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
247
+ - `network_module` `networks.lora` and `networks.dylora` are available.
248
+
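+ As a rough sketch of what the LoRA+ ratio means (illustrative only; the parameter naming is assumed and the actual grouping logic in sd-scripts differs), the LoRA "up" weights are simply placed in an optimizer param group with `base_lr * ratio`:
+
+ ```python
+ import torch
+
+ def loraplus_param_groups(network: torch.nn.Module, base_lr: float, ratio: float = 16.0):
+     # assumes LoRA "up"/B weights can be identified by name (e.g. lora_up / lora_B)
+     up, rest = [], []
+     for name, param in network.named_parameters():
+         if not param.requires_grad:
+             continue
+         (up if ("lora_up" in name or "lora_B" in name) else rest).append(param)
+     return [
+         {"params": rest, "lr": base_lr},
+         {"params": up, "lr": base_lr * ratio},  # the LoRA-B side trains faster
+     ]
+
+ # optimizer = torch.optim.AdamW(loraplus_param_groups(lora_network, base_lr=1e-4, ratio=16.0))
+ ```
+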
249
+ - The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru!
250
+ - The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file.
251
+ - See [About masked loss](./docs/masked_loss_README.md) for details.
252
+
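+ Conceptually, the alpha mask just weights the per-element loss so that fully transparent regions contribute nothing, roughly as in the sketch below (illustrative only; the actual training code applies the mask to the latent-space loss as described in the masked loss doc linked above):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def alpha_masked_loss(pred: torch.Tensor, target: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
+     # alpha in [0, 1], broadcastable to pred's shape; 0 = fully transparent = ignored
+     per_element = F.mse_loss(pred, target, reduction="none")
+     weighted = per_element * alpha
+     return weighted.sum() / alpha.sum().clamp(min=1.0)
+ ```
+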
253
+ - LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
254
+ - Specify the learning rate and dim (rank) for each block.
255
+ - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
256
+
257
+ - Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath!
258
+ - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended.
259
+ - When specifying from the command line, use `=` like `--learning_rate=-1e-7`.
260
+
261
+ - Training scripts can now output training settings to wandb or TensorBoard logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
262
+ - Some settings, such as API keys and directory specifications, are not output due to security issues.
263
+
264
+ - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
265
+
266
+ - `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
267
+ - This resolves the issue where the order of data loading from DataSet changes when resuming training.
268
+ - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
269
+ - If `--resume` is specified, the step saved in the state is used.
270
+ - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
271
+
272
+ - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
273
+ - It seems that the model file loading is faster in the WSL environment etc.
274
+ - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
275
+
276
+ - When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath!
277
+
278
+ - Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291](
279
+ https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
280
+ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
281
+
282
+ - Fixed a bug where `caption_separator` could not be specified per subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) Thanks to rockerBOO!
283
+
284
+ - Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th!
285
+
286
+ - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
287
+
288
+ - Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported.
289
+
290
+ #### 変更点
291
+
292
+ - devブランチがmainにマージされました。ドキュメントの整備が遅れており申し訳ありません。少しずつ整備していきます。
293
+ - マージ直前の状態が Version 0.8.8 としてリリースされていますので、問題があればそちらをご利用ください。
294
+ - 以下の変更が含まれます。
295
+
296
+ - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
297
+ - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
298
+ - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
299
+ - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。
300
+ - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。
301
+ - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。
302
+ - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。
303
+
304
+ - SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
305
+ - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。
306
+ - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。
307
+ - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。
308
+ - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
309
+ - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。
310
+
311
+ - LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
312
+ - LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
313
+ - `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
314
+ - `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
315
+ - 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
316
+ - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。
317
+
318
+ - 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。
319
+ - 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。
320
+ - 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。
321
+
322
+ - SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
323
+ - ブロックごとに学習率および dim (rank) を指定することができます。
324
+ - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。
325
+
326
+ - `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。
327
+ - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。
328
+ - コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。
329
+
330
+ - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。
331
+ - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。
332
+
333
+ - SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。
334
+
335
+ - `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
336
+ - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
337
+ - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます)
338
+ - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
339
+ - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。
340
+
341
+ - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
342
+ - WSL 環境等でモデルファイルの読み込みが高速化されるようです。
343
+ - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
344
+
345
+ - ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。
346
+
347
+ - `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291](
348
+ https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
349
+ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。
350
+
351
+ - データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。
352
+
353
+ - ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。
354
+
355
+ - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
356
+
357
+ - `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
358
+
359
+
360
+ ### Oct 27, 2024 / 2024-10-27:
361
+
362
+ - `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
363
+ - This will be included in the next release.
364
+ - `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します(32GBあれば十分です)。
365
+ - これは次回リリースに含まれます。
366
+
367
+ ### Oct 26, 2024 / 2024-10-26:
368
+
369
+ - Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
370
+ - It will be included in the next release.
371
+
372
+ - `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
373
+ - 以上は次回リリースに含まれます。
374
+
375
+ ### Sep 13, 2024 / 2024-09-13:
376
+
377
+ - `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
378
+ - `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
379
+ - `sdxl_merge_lora.py` also supports LBW.
380
+ - See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
381
+ - These will be included in the next release.
382
+
383
+ - `sdxl_merge_lora.py` で OFT がサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
384
+ - `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
385
+ - `sdxl_merge_lora.py` でも LBW がサポートされました。
386
+ - LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
387
+ - 以上は次回リリースに含まれます。
388
+
389
+ ### Jun 23, 2024 / 2024-06-23:
390
+
391
+ - Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
392
+
393
+ - `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
394
+
395
+ ### Apr 7, 2024 / 2024-04-07: v0.8.7
396
+
397
+ - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
398
+
399
+ - Scheduled Huber Loss の `huber_schedule` のデフォルト値を `exponential` から、より良い結果が期待できる `snr` に変更しました。
400
+
401
+ ### Apr 7, 2024 / 2024-04-07: v0.8.6
402
+
403
+ #### Highlights
404
+
405
+ - The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
406
+ - Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
407
+ - `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
408
+ - `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
409
+ - Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
410
+ - When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue.
411
+ - A warning is displayed at the start of training if such information is included in the command line.
412
+ - Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
413
+ - See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
414
+ - Training on Colab seems to stop due to the log output. Try specifying the `--console_log_simple` option in the training script to disable rich logging.
415
+ - Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
416
+
417
+ #### Training scripts
418
+
419
+ - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
420
+ - Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced.
421
+ - DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
422
+ - The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details.
423
+ - Scheduled Huber Loss has been introduced to each training script. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
424
+ - The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
425
+ - The option `--save_state_on_train_end` is added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
426
+ - The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and are ignored when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
427
+
428
+ #### Dataset settings
429
+
430
+ - The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
431
+ - The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
432
+ - Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
433
+ - Some features are added to the dataset subset settings.
434
+ - `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
435
+ - Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped.
436
+ - `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled.
437
+ - `keep_tokens_separator` is updated to be usable twice in the caption. For example, with `keep_tokens_separator="|||"` and the caption `1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general`, the part after the second `|||` is not shuffled or dropped and remains at the end.
438
+ - The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
439
+ - See [Dataset config](./docs/config_README-en.md) for details.
440
+ - The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
441
+
442
+ #### Image tagging
443
+
444
+ - The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
445
+ - Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
446
+ - The model is now saved in a subdirectory named after `--repo_id` in `tag_image_by_wd14_tagger.py`. This allows models from multiple repo_ids to be cached. Please delete unnecessary files under `--model_dir`.
447
+ - Some options are added to `tag_image_by_wd14_tagger.py`.
448
+ - Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
449
+ - Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag`
450
+ - Output character tags first `--character_tags_first`
451
+ - Expand character tags and series `--character_tag_expand`
452
+ - Specify tags to output first `--always_first_tags`
453
+ - Replace tags `--tag_replacement`
454
+ - See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
455
+ - Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
456
+
457
+ #### About Masked loss
458
+
459
+ The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
460
+
461
+ The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
462
+
463
+ The ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as masked (the loss is calculated only for masked pixels), and 0 is treated as unmasked. Pixel values 0-255 are converted to weights 0-1 (i.e., a pixel value of 128 gives half the loss weight). See the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
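+ 
+ As a rough illustration of the weighting described above (a simplified sketch, not the actual implementation in the training scripts; the function and array names are made up for this example), the R channel of the mask can be turned into per-pixel loss weights like this:
+ 
+ ```python
+ import numpy as np
+ 
+ def masked_mse(pred: np.ndarray, target: np.ndarray, mask_rgb: np.ndarray) -> float:
+     """Weight a per-pixel MSE loss by the R channel of an RGB mask image (illustrative only).
+ 
+     pred, target: float arrays of shape (H, W, C)
+     mask_rgb:     uint8 array of shape (H, W, 3); R=255 -> full weight, R=0 -> ignored
+     """
+     weights = mask_rgb[..., 0].astype(np.float32) / 255.0   # 0-255 -> 0.0-1.0 (128 -> 0.5)
+     per_pixel = ((pred - target) ** 2).mean(axis=-1)         # plain MSE per pixel
+     return float((per_pixel * weights).sum() / max(weights.sum(), 1e-6))
+ ```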
464
+
465
+ #### About Scheduled Huber Loss
466
+
467
+ Scheduled Huber Loss has been introduced to each training script. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
468
+
469
+ With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
470
+
471
+ To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.
472
+
473
+ Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.
474
+
475
+ The newly added arguments `loss_type`, `huber_schedule`, and `huber_c` allow for the selection of the loss function type (Huber, smooth L1, MSE), the scheduling method (exponential, constant, SNR), and the Huber parameter. This enables optimization based on the characteristics of the dataset.
476
+
477
+ See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
478
+
479
+ - `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
480
+ - `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
481
+ - `huber_c`: Specify the Huber's parameter. The default is `0.1`.
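+ 
+ As a rough sketch of the idea (a simplification for illustration only, not the code from the PR; the function name and the exponential interpolation below are assumptions), the Huber parameter can be interpolated per sample so that the loss behaves like MSE at low-noise timesteps and like a robust pseudo-Huber loss at the noisiest timesteps:
+ 
+ ```python
+ import torch
+ 
+ def scheduled_pseudo_huber_loss(pred: torch.Tensor, target: torch.Tensor,
+                                 timesteps: torch.Tensor, num_train_timesteps: int = 1000,
+                                 huber_c: float = 0.1) -> torch.Tensor:
+     """Simplified sketch of a scheduled pseudo-Huber loss (not the repository implementation)."""
+     # Interpolate c per sample: c ~ 1 at low-noise timesteps (behaves like MSE),
+     # c -> huber_c at the noisiest timesteps (more robust to outliers, L1-like).
+     t = timesteps.float() / num_train_timesteps
+     c = torch.exp(t * torch.log(torch.tensor(huber_c, device=pred.device)))
+     c = c.view(-1, *([1] * (pred.dim() - 1)))              # broadcast over (B, C, H, W)
+ 
+     diff = pred - target
+     loss = 2 * c * (torch.sqrt(diff ** 2 + c ** 2) - c)    # pseudo-Huber per element
+     return loss.mean()
+ ```
+ 
+ When trying this feature, starting with something like `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` may be a good idea (as also noted in the Japanese section below).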
482
+
483
+ Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
484
+
485
+ #### 主要な変更点
486
+
487
+ - 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。
488
+ - 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。
489
+ - `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。
490
+ - `bitsandbytes` が公式に Windows をサポートしたため複雑な手順が不要になりました。
491
+ - また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。
492
+ - wandb へのログ出力が有効の場合、コマンドライン全体が公開されます。そのため、コマンドラインに wandb の API キーや HuggingFace のトークンなどが含まれる場合、設定ファイル(`.toml`)への記載をお勧めします。問題提起していただいた bghira 氏に感謝します。
493
+ - このような場合には学習開始時に警告が表示されます。
494
+ - また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。
495
+ - 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。
496
+ - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効にしてお試しください。
497
+ - その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。
498
+
499
+ #### 学習スクリプト
500
+
501
+ - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
502
+ - `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。
503
+ - DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。
504
+ - 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。
505
+ - 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は当該 PR および [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。
506
+ - 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
507
+ - 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
508
+ - 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。
509
+
510
+ #### データセット設定
511
+
512
+ - データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
513
+ - データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。
514
+ - データセットのサブセット設定にいくつかの機能を追加しました。
515
+ - シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。
516
+ - `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。
517
+ - `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
518
+ - 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
519
+ - 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。
520
+ - DreamBooth 方式の DataSet で画像情報(サイズ、キャプション)をキャッシュする機能が追加されました。PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178)、[#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) KohakuBlueleaf 氏に感謝します。詳細は [データセット設定](./docs/config_README-ja.md#dreambooth-方式専用のオプション) をご覧ください。
521
+ - データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
522
+
523
+ #### 画像のタグ付け
524
+
525
+ - `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
526
+ - Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
527
+ - `tag_image_by_wd14_tagger.py` で、モデルを `--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
528
+ - `tag_image_by_wd14_tagger.py` にいくつかのオプションを追加しました。
529
+ - 一部は PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) で追加されました。Disty0 氏に感謝します。
530
+ - レーティングタグを出力する `--use_rating_tags` および `--use_rating_tags_as_last_tag`
531
+ - キャラクタタグを最初に出力する `--character_tags_first`
532
+ - キャラクタタグとシリーズを展開する `--character_tag_expand`
533
+ - 常に最初に出力するタグを指定する `--always_first_tags`
534
+ - タグを置換する `--tag_replacement`
535
+ - 詳細は [タグ付けに関するドキュメント](./docs/wd14_tagger_README-ja.md) をご覧ください。
536
+ - `make_captions.py` で `--beam_search` を指定し `--num_beams` に2以上の値を指定した時のエラーを修正しました。
537
+
538
+ #### マスクロスについて
539
+
540
+ 各学習スクリプトでマスクロスをサポートしました。マスクロスを有効にするには `--masked_loss` オプションを指定してください。
541
+
542
+ 機能は完全にテストされていないため、不具合があるかもしれません。その場合は Issue を立てていただけると助かります。
543
+
544
+ マスクの指定には ControlNet データセットを使用します。マスク画像は RGB 画像である必要があります。R チャンネルのピクセル値 255 がロス計算対象、0 がロス計算対象外になります。0-255 の値は、0-1 の範囲に変換されます(つまりピクセル値 128 の部分はロスの重みが半分になります)。データセットの詳細は [LLLite ドキュメント](./docs/train_lllite_README-ja.md#データセットの準備) をご覧ください。
545
+
546
+ #### Scheduled Huber Loss について
547
+
548
+ 各学習スクリプトに、学習データ中の異常値や外れ値(data corruption)への耐性を高めるための手法、Scheduled Huber Lossが導入されました。
549
+
550
+ 従来のMSE(L2)損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。
551
+
552
+ この手法ではHuber損失関数の適用を工夫し、学習の初期段階(ノイズが大きい場合)ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。
553
+
554
+ 実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。
555
+
556
+ 具体的には、新たに追加された引数 `loss_type`、`huber_schedule`、`huber_c` で、損失関数の種類(Huber, smooth L1, MSE)とスケジューリング方法(exponential, constant, SNR)を選択できます。これによりデータセットに応じた最適化が可能になります。
557
+
558
+ 詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。
559
+
560
+ - `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。
561
+ - `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `snr` です。
562
+ - `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。
563
+
564
+ PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。
565
+
566
+ 最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
567
+
568
+ ## Additional Information
569
+
570
+ ### Naming of LoRA
571
+
572
+ The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
573
+
574
+ 1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
575
+
576
+ LoRA for Linear layers and Conv2d layers with 1x1 kernel
577
+
578
+ 2. __LoRA-C3Lier__ : (LoRA for __C__ onvolutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
579
+
580
+ In addition to 1., LoRA for Conv2d layers with 3x3 kernel
581
+
582
+ LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
583
+ <!--
584
+ LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
585
+
586
+ To use LoRA-C3Lier with Web UI, please use our extension.
587
+ -->
588
+
589
+ ### Sample image generation during training
590
+ A prompt file might look like this, for example
591
+
592
+ ```
593
+ # prompt 1
594
+ masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
595
+
596
+ # prompt 2
597
+ masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
598
+ ```
599
+
600
+ Lines beginning with `#` are comments. You can specify options for the generated image, such as `--n`, after the prompt. The following options can be used.
601
+
602
+ * `--n` Negative prompt up to the next option.
603
+ * `--w` Specifies the width of the generated image.
604
+ * `--h` Specifies the height of the generated image.
605
+ * `--d` Specifies the seed of the generated image.
606
+ * `--l` Specifies the CFG scale of the generated image.
607
+ * `--s` Specifies the number of steps in the generation.
608
+
609
+ Prompt weighting such as `( )` and `[ ]` also works.
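+ 
+ As a small illustration only (this is not the repository's actual parser; the function name, regular expression, and the choice of returning `None` for missing options are assumptions made for this sketch), a prompt-file line in the format above could be split into its prompt text and options like this:
+ 
+ ```python
+ import re
+ 
+ # Matches "--<option> <value>" pairs up to the next option or the end of the line.
+ OPTION_PATTERN = re.compile(r"--([nwhdls])\s+(.+?)(?=\s--[nwhdls]\s|$)")
+ 
+ def parse_prompt_line(line: str) -> dict:
+     line = line.strip()
+     if not line or line.startswith("#"):      # skip comments and blank lines
+         return {}
+     prompt = line.split(" --")[0].strip()     # text before the first option
+     opts = {key: value.strip() for key, value in OPTION_PATTERN.findall(line)}
+     return {
+         "prompt": prompt,
+         "negative_prompt": opts.get("n"),                        # --n
+         "width": int(opts["w"]) if "w" in opts else None,        # --w
+         "height": int(opts["h"]) if "h" in opts else None,       # --h
+         "seed": int(opts["d"]) if "d" in opts else None,         # --d
+         "cfg_scale": float(opts["l"]) if "l" in opts else None,  # --l
+         "steps": int(opts["s"]) if "s" in opts else None,        # --s
+     }
+ ```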
XTI_hijack.py ADDED
@@ -0,0 +1,204 @@
1
+ import torch
2
+ from library.device_utils import init_ipex
3
+ init_ipex()
4
+
5
+ from typing import Union, List, Optional, Dict, Any, Tuple
6
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
7
+
8
+ from library.original_unet import SampleOutput
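+ 
+ # Note: these hijack functions replace the original U-Net forward pass so that each
+ # cross-attention block receives its own text embedding (apparently for XTI, i.e.
+ # per-layer Textual Inversion). Based on the indexing below, encoder_hidden_states is
+ # expected to be a sequence of embeddings: indices 0-5 for the down blocks, 6 for the
+ # mid block, and 7 onward for the up blocks of the SD U-Net.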
9
+
10
+
11
+ def unet_forward_XTI(
12
+ self,
13
+ sample: torch.FloatTensor,
14
+ timestep: Union[torch.Tensor, float, int],
15
+ encoder_hidden_states: torch.Tensor,
16
+ class_labels: Optional[torch.Tensor] = None,
17
+ return_dict: bool = True,
18
+ ) -> Union[Dict, Tuple]:
19
+ r"""
20
+ Args:
21
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
22
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
23
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
24
+ return_dict (`bool`, *optional*, defaults to `True`):
25
+ Whether or not to return a dict instead of a plain tuple.
26
+
27
+ Returns:
28
+ `SampleOutput` or `tuple`:
29
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
30
+ """
31
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
32
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
33
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
34
+ # on the fly if necessary.
35
+ # デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
36
+ # ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
37
+ # 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
38
+ default_overall_up_factor = 2**self.num_upsamplers
39
+
40
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
41
+ # 64で割り切れないときはupsamplerにサイズを伝える
42
+ forward_upsample_size = False
43
+ upsample_size = None
44
+
45
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
46
+ # logger.info("Forward upsample size to force interpolation output size.")
47
+ forward_upsample_size = True
48
+
49
+ # 1. time
50
+ timesteps = timestep
51
+ timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
52
+
53
+ t_emb = self.time_proj(timesteps)
54
+
55
+ # timesteps does not contain any weights and will always return f32 tensors
56
+ # but time_embedding might actually be running in fp16. so we need to cast here.
57
+ # there might be better ways to encapsulate this.
58
+ # timestepsは重みを含まないので常にfloat32のテンソルを返す
59
+ # しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
60
+ # time_projでキャストしておけばいいんじゃね?
61
+ t_emb = t_emb.to(dtype=self.dtype)
62
+ emb = self.time_embedding(t_emb)
63
+
64
+ # 2. pre-process
65
+ sample = self.conv_in(sample)
66
+
67
+ # 3. down
68
+ down_block_res_samples = (sample,)
69
+ down_i = 0
70
+ for downsample_block in self.down_blocks:
71
+ # downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
72
+ # まあこちらのほうがわかりやすいかもしれない
73
+ if downsample_block.has_cross_attention:
74
+ sample, res_samples = downsample_block(
75
+ hidden_states=sample,
76
+ temb=emb,
77
+ encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
78
+ )
79
+ down_i += 2
80
+ else:
81
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
82
+
83
+ down_block_res_samples += res_samples
84
+
85
+ # 4. mid
86
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
87
+
88
+ # 5. up
89
+ up_i = 7
90
+ for i, upsample_block in enumerate(self.up_blocks):
91
+ is_final_block = i == len(self.up_blocks) - 1
92
+
93
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
94
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
95
+
96
+ # if we have not reached the final block and need to forward the upsample size, we do it here
97
+ # 前述のように最後のブロック以外ではupsample_sizeを伝える
98
+ if not is_final_block and forward_upsample_size:
99
+ upsample_size = down_block_res_samples[-1].shape[2:]
100
+
101
+ if upsample_block.has_cross_attention:
102
+ sample = upsample_block(
103
+ hidden_states=sample,
104
+ temb=emb,
105
+ res_hidden_states_tuple=res_samples,
106
+ encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
107
+ upsample_size=upsample_size,
108
+ )
109
+ up_i += 3
110
+ else:
111
+ sample = upsample_block(
112
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
113
+ )
114
+
115
+ # 6. post-process
116
+ sample = self.conv_norm_out(sample)
117
+ sample = self.conv_act(sample)
118
+ sample = self.conv_out(sample)
119
+
120
+ if not return_dict:
121
+ return (sample,)
122
+
123
+ return SampleOutput(sample=sample)
124
+
125
+
126
+ def downblock_forward_XTI(
127
+ self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
128
+ ):
129
+ output_states = ()
130
+ i = 0
131
+
132
+ for resnet, attn in zip(self.resnets, self.attentions):
133
+ if self.training and self.gradient_checkpointing:
134
+
135
+ def create_custom_forward(module, return_dict=None):
136
+ def custom_forward(*inputs):
137
+ if return_dict is not None:
138
+ return module(*inputs, return_dict=return_dict)
139
+ else:
140
+ return module(*inputs)
141
+
142
+ return custom_forward
143
+
144
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
145
+ hidden_states = torch.utils.checkpoint.checkpoint(
146
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
147
+ )[0]
148
+ else:
149
+ hidden_states = resnet(hidden_states, temb)
150
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
151
+
152
+ output_states += (hidden_states,)
153
+ i += 1
154
+
155
+ if self.downsamplers is not None:
156
+ for downsampler in self.downsamplers:
157
+ hidden_states = downsampler(hidden_states)
158
+
159
+ output_states += (hidden_states,)
160
+
161
+ return hidden_states, output_states
162
+
163
+
164
+ def upblock_forward_XTI(
165
+ self,
166
+ hidden_states,
167
+ res_hidden_states_tuple,
168
+ temb=None,
169
+ encoder_hidden_states=None,
170
+ upsample_size=None,
171
+ ):
172
+ i = 0
173
+ for resnet, attn in zip(self.resnets, self.attentions):
174
+ # pop res hidden states
175
+ res_hidden_states = res_hidden_states_tuple[-1]
176
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
177
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
178
+
179
+ if self.training and self.gradient_checkpointing:
180
+
181
+ def create_custom_forward(module, return_dict=None):
182
+ def custom_forward(*inputs):
183
+ if return_dict is not None:
184
+ return module(*inputs, return_dict=return_dict)
185
+ else:
186
+ return module(*inputs)
187
+
188
+ return custom_forward
189
+
190
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
191
+ hidden_states = torch.utils.checkpoint.checkpoint(
192
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
193
+ )[0]
194
+ else:
195
+ hidden_states = resnet(hidden_states, temb)
196
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
197
+
198
+ i += 1
199
+
200
+ if self.upsamplers is not None:
201
+ for upsampler in self.upsamplers:
202
+ hidden_states = upsampler(hidden_states, upsample_size)
203
+
204
+ return hidden_states
_typos.toml ADDED
@@ -0,0 +1,35 @@
1
+ # Files for typos
2
+ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started
3
+
4
+ [default.extend-identifiers]
5
+ ddPn08="ddPn08"
6
+
7
+ [default.extend-words]
8
+ NIN="NIN"
9
+ parms="parms"
10
+ nin="nin"
11
+ extention="extention" # Intentionally left
12
+ nd="nd"
13
+ shs="shs"
14
+ sts="sts"
15
+ scs="scs"
16
+ cpc="cpc"
17
+ coc="coc"
18
+ cic="cic"
19
+ msm="msm"
20
+ usu="usu"
21
+ ici="ici"
22
+ lvl="lvl"
23
+ dii="dii"
24
+ muk="muk"
25
+ ori="ori"
26
+ hru="hru"
27
+ rik="rik"
28
+ koo="koo"
29
+ yos="yos"
30
+ wn="wn"
31
+ hime="hime"
32
+
33
+
34
+ [files]
35
+ extend-exclude = ["_typos.toml", "venv"]
app.py ADDED
@@ -0,0 +1,698 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ LoRA Trainer Funcional para Hugging Face
4
+ Baseado no kohya-ss sd-scripts
5
+ """
6
+
7
+ import gradio as gr
8
+ import os
9
+ import sys
10
+ import json
11
+ import subprocess
12
+ import shutil
13
+ import zipfile
14
+ import tempfile
15
+ import toml
16
+ import logging
17
+ from pathlib import Path
18
+ from typing import Optional, Tuple, List, Dict, Any
19
+ import time
20
+ import threading
21
+ import queue
22
+
23
+ # Adicionar o diretório sd-scripts ao path
24
+ sys.path.insert(0, str(Path(__file__).parent / "sd-scripts"))
25
+
26
+ # Configurar logging
27
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
+ logger = logging.getLogger(__name__)
29
+
30
+ class LoRATrainerHF:
31
+ def __init__(self):
32
+ self.base_dir = Path("/tmp/lora_training")
33
+ self.base_dir.mkdir(exist_ok=True)
34
+
35
+ self.models_dir = self.base_dir / "models"
36
+ self.models_dir.mkdir(exist_ok=True)
37
+
38
+ self.projects_dir = self.base_dir / "projects"
39
+ self.projects_dir.mkdir(exist_ok=True)
40
+
41
+ self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"
42
+
43
+ # URLs dos modelos
44
+ self.model_urls = {
45
+ "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
46
+ "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
47
+ "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
48
+ "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
49
+ }
50
+
51
+ self.training_process = None
52
+ self.training_output_queue = queue.Queue()
53
+
54
+ def install_dependencies(self) -> str:
55
+ """Instala as dependências necessárias"""
56
+ try:
57
+ logger.info("Instalando dependências...")
58
+
59
+ # Lista de pacotes necessários
60
+ packages = [
61
+ "torch>=2.0.0",
62
+ "torchvision>=0.15.0",
63
+ "diffusers>=0.21.0",
64
+ "transformers>=4.25.0",
65
+ "accelerate>=0.20.0",
66
+ "safetensors>=0.3.0",
67
+ "huggingface-hub>=0.16.0",
68
+ "xformers>=0.0.20",
69
+ "bitsandbytes>=0.41.0",
70
+ "opencv-python>=4.7.0",
71
+ "Pillow>=9.0.0",
72
+ "numpy>=1.21.0",
73
+ "tqdm>=4.64.0",
74
+ "toml>=0.10.0",
75
+ "tensorboard>=2.13.0",
76
+ "wandb>=0.15.0",
77
+ "scipy>=1.9.0",
78
+ "matplotlib>=3.5.0",
79
+ "datasets>=2.14.0",
80
+ "peft>=0.5.0",
81
+ "omegaconf>=2.3.0"
82
+ ]
83
+
84
+ # Instalar pacotes
85
+ for package in packages:
86
+ try:
87
+ subprocess.run([
88
+ sys.executable, "-m", "pip", "install", package, "--quiet"
89
+ ], check=True, capture_output=True, text=True)
90
+ logger.info(f"✓ {package} instalado")
91
+ except subprocess.CalledProcessError as e:
92
+ logger.warning(f"⚠ Erro ao instalar {package}: {e}")
93
+
94
+ return "✅ Dependências instaladas com sucesso!"
95
+
96
+ except Exception as e:
97
+ logger.error(f"Erro ao instalar dependências: {e}")
98
+ return f"❌ Erro ao instalar dependências: {e}"
99
+
100
+ def download_model(self, model_choice: str, custom_url: str = "") -> str:
101
+ """Download do modelo base"""
102
+ try:
103
+ if custom_url.strip():
104
+ model_url = custom_url.strip()
105
+ model_name = model_url.split("/")[-1]
106
+ else:
107
+ if model_choice not in self.model_urls:
108
+ return f"❌ Modelo '{model_choice}' não encontrado"
109
+ model_url = self.model_urls[model_choice]
110
+ model_name = model_url.split("/")[-1]
111
+
112
+ model_path = self.models_dir / model_name
113
+
114
+ if model_path.exists():
115
+ return f"✅ Modelo já existe: {model_name}"
116
+
117
+ logger.info(f"Baixando modelo: {model_url}")
118
+
119
+ # Download usando wget
120
+ result = subprocess.run([
121
+ "wget", "-O", str(model_path), model_url, "--progress=bar:force"
122
+ ], capture_output=True, text=True)
123
+
124
+ if result.returncode == 0:
125
+ return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"
126
+ else:
127
+ return f"❌ Erro no download: {result.stderr}"
128
+
129
+ except Exception as e:
130
+ logger.error(f"Erro ao baixar modelo: {e}")
131
+ return f"❌ Erro ao baixar modelo: {e}"
132
+
133
+ def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
134
+ """Processa o dataset enviado"""
135
+ try:
136
+ if not dataset_zip:
137
+ return "❌ Nenhum dataset foi enviado", ""
138
+
139
+ if not project_name.strip():
140
+ return "❌ Nome do projeto é obrigatório", ""
141
+
142
+ project_name = project_name.strip().replace(" ", "_")
143
+ project_dir = self.projects_dir / project_name
144
+ project_dir.mkdir(exist_ok=True)
145
+
146
+ dataset_dir = project_dir / "dataset"
147
+ if dataset_dir.exists():
148
+ shutil.rmtree(dataset_dir)
149
+ dataset_dir.mkdir()
150
+
151
+ # Extrair ZIP
152
+ with zipfile.ZipFile(dataset_zip.name, 'r') as zip_ref:
153
+ zip_ref.extractall(dataset_dir)
154
+
155
+ # Analisar dataset
156
+ image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}
157
+ images = []
158
+ captions = []
159
+
160
+ for file_path in dataset_dir.rglob("*"):
161
+ if file_path.suffix.lower() in image_extensions:
162
+ images.append(file_path)
163
+
164
+ # Procurar caption
165
+ caption_path = file_path.with_suffix('.txt')
166
+ if caption_path.exists():
167
+ captions.append(caption_path)
168
+
169
+ info = f"✅ Dataset processado!\n"
170
+ info += f"📁 Projeto: {project_name}\n"
171
+ info += f"🖼️ Imagens: {len(images)}\n"
172
+ info += f"📝 Captions: {len(captions)}\n"
173
+ info += f"📂 Diretório: {dataset_dir}"
174
+
175
+ return info, str(dataset_dir)
176
+
177
+ except Exception as e:
178
+ logger.error(f"Erro ao processar dataset: {e}")
179
+ return f"❌ Erro ao processar dataset: {e}", ""
180
+
181
+ def create_training_config(self,
182
+ project_name: str,
183
+ dataset_dir: str,
184
+ model_choice: str,
185
+ custom_model_url: str,
186
+ resolution: int,
187
+ batch_size: int,
188
+ epochs: int,
189
+ learning_rate: float,
190
+ text_encoder_lr: float,
191
+ network_dim: int,
192
+ network_alpha: int,
193
+ lora_type: str,
194
+ optimizer: str,
195
+ scheduler: str,
196
+ flip_aug: bool,
197
+ shuffle_caption: bool,
198
+ keep_tokens: int,
199
+ clip_skip: int,
200
+ mixed_precision: str,
201
+ save_every_n_epochs: int,
202
+ max_train_steps: int) -> str:
203
+ """Cria configuração de treinamento"""
204
+ try:
205
+ if not project_name.strip():
206
+ return "❌ Nome do projeto é obrigatório"
207
+
208
+ project_name = project_name.strip().replace(" ", "_")
209
+ project_dir = self.projects_dir / project_name
210
+ project_dir.mkdir(exist_ok=True)
211
+
212
+ output_dir = project_dir / "output"
213
+ output_dir.mkdir(exist_ok=True)
214
+
215
+ log_dir = project_dir / "logs"
216
+ log_dir.mkdir(exist_ok=True)
217
+
218
+ # Determinar modelo
219
+ if custom_model_url.strip():
220
+ model_name = custom_model_url.strip().split("/")[-1]
221
+ else:
222
+ model_name = self.model_urls[model_choice].split("/")[-1]
223
+
224
+ model_path = self.models_dir / model_name
225
+
226
+ if not model_path.exists():
227
+ return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro."
228
+
229
+ # Configuração do dataset
230
+ dataset_config = {
231
+ "general": {
232
+ "shuffle_caption": shuffle_caption,
233
+ "caption_extension": ".txt",
234
+ "keep_tokens": keep_tokens,
235
+ "flip_aug": flip_aug,
236
+ "color_aug": False,
237
+ "face_crop_aug_range": None,
238
+ "random_crop": False,
239
+ "debug_dataset": False
240
+ },
241
+ "datasets": [{
242
+ "resolution": resolution,
243
+ "batch_size": batch_size,
244
+ "subsets": [{
245
+ "image_dir": str(dataset_dir),
246
+ "num_repeats": 1
247
+ }]
248
+ }]
249
+ }
250
+
251
+ # Configuração de treinamento
252
+ training_config = {
253
+ "model_arguments": {
254
+ "pretrained_model_name_or_path": str(model_path),
255
+ "v2": False,
256
+ "v_parameterization": False,
257
+ "clip_skip": clip_skip
258
+ },
259
+ "dataset_arguments": {
260
+ "dataset_config": str(project_dir / "dataset_config.toml")
261
+ },
262
+ "training_arguments": {
263
+ "output_dir": str(output_dir),
264
+ "output_name": project_name,
265
+ "save_precision": "fp16",
266
+ "save_every_n_epochs": save_every_n_epochs,
267
+ "max_train_epochs": epochs if max_train_steps == 0 else None,
268
+ "max_train_steps": max_train_steps if max_train_steps > 0 else None,
269
+ "train_batch_size": batch_size,
270
+ "gradient_accumulation_steps": 1,
271
+ "learning_rate": learning_rate,
272
+ "text_encoder_lr": text_encoder_lr,
273
+ "lr_scheduler": scheduler,
274
+ "lr_warmup_steps": 0,
275
+ "optimizer_type": optimizer,
276
+ "mixed_precision": mixed_precision,
277
+ "save_model_as": "safetensors",
278
+ "seed": 42,
279
+ "max_data_loader_n_workers": 2,
280
+ "persistent_data_loader_workers": True,
281
+ "gradient_checkpointing": True,
282
+ "xformers": True,
283
+ "lowram": True,
284
+ "cache_latents": True,
285
+ "cache_latents_to_disk": True,
286
+ "logging_dir": str(log_dir),
287
+ "log_with": "tensorboard"
288
+ },
289
+ "network_arguments": {
290
+ "network_module": "networks.lora" if lora_type == "LoRA" else "networks.dylora",
291
+ "network_dim": network_dim,
292
+ "network_alpha": network_alpha,
293
+ "network_train_unet_only": False,
294
+ "network_train_text_encoder_only": False
295
+ }
296
+ }
297
+
298
+ # Adicionar argumentos específicos para LoCon
299
+ if lora_type == "LoCon":
300
+ training_config["network_arguments"]["network_module"] = "networks.lora"
301
+ training_config["network_arguments"]["conv_dim"] = max(1, network_dim // 2)
302
+ training_config["network_arguments"]["conv_alpha"] = max(1, network_alpha // 2)
303
+
304
+ # Salvar configurações
305
+ dataset_config_path = project_dir / "dataset_config.toml"
306
+ training_config_path = project_dir / "training_config.toml"
307
+
308
+ with open(dataset_config_path, 'w') as f:
309
+ toml.dump(dataset_config, f)
310
+
311
+ with open(training_config_path, 'w') as f:
312
+ toml.dump(training_config, f)
313
+
314
+ return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}"
315
+
316
+ except Exception as e:
317
+ logger.error(f"Erro ao criar configuração: {e}")
318
+ return f"❌ Erro ao criar configuração: {e}"
319
+
320
+ def start_training(self, project_name: str) -> str:
321
+ """Inicia o treinamento"""
322
+ try:
323
+ if not project_name.strip():
324
+ return "❌ Nome do projeto é obrigatório"
325
+
326
+ project_name = project_name.strip().replace(" ", "_")
327
+ project_dir = self.projects_dir / project_name
328
+
329
+ training_config_path = project_dir / "training_config.toml"
330
+ if not training_config_path.exists():
331
+ return "❌ Configuração não encontrada. Crie a configuração primeiro."
332
+
333
+ # Script de treinamento
334
+ train_script = self.sd_scripts_dir / "train_network.py"
335
+ if not train_script.exists():
336
+ return "❌ Script de treinamento não encontrado"
337
+
338
+ # Comando de treinamento
339
+ cmd = [
340
+ sys.executable,
341
+ str(train_script),
342
+ "--config_file", str(training_config_path)
343
+ ]
344
+
345
+ logger.info(f"Iniciando treinamento: {' '.join(cmd)}")
346
+
347
+ # Executar em thread separada
348
+ def run_training():
349
+ try:
350
+ process = subprocess.Popen(
351
+ cmd,
352
+ stdout=subprocess.PIPE,
353
+ stderr=subprocess.STDOUT,
354
+ text=True,
355
+ bufsize=1,
356
+ universal_newlines=True,
357
+ cwd=str(self.sd_scripts_dir)
358
+ )
359
+
360
+ self.training_process = process
361
+
362
+ for line in process.stdout:
363
+ self.training_output_queue.put(line.strip())
364
+ logger.info(line.strip())
365
+
366
+ process.wait()
367
+
368
+ if process.returncode == 0:
369
+ self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!")
370
+ else:
371
+ self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})")
372
+
373
+ except Exception as e:
374
+ self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}")
375
+ finally:
376
+ self.training_process = None
377
+
378
+ # Iniciar thread
379
+ training_thread = threading.Thread(target=run_training)
380
+ training_thread.daemon = True
381
+ training_thread.start()
382
+
383
+ return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo."
384
+
385
+ except Exception as e:
386
+ logger.error(f"Erro ao iniciar treinamento: {e}")
387
+ return f"❌ Erro ao iniciar treinamento: {e}"
388
+
389
+ def get_training_output(self) -> str:
390
+ """Obtém output do treinamento"""
391
+ output_lines = []
392
+ try:
393
+ while not self.training_output_queue.empty():
394
+ line = self.training_output_queue.get_nowait()
395
+ output_lines.append(line)
396
+ except queue.Empty:
397
+ pass
398
+
399
+ if output_lines:
400
+ return "\n".join(output_lines)
401
+ elif self.training_process and self.training_process.poll() is None:
402
+ return "🔄 Treinamento em andamento..."
403
+ else:
404
+ return "⏸️ Nenhum treinamento ativo"
405
+
406
+ def stop_training(self) -> str:
407
+ """Para o treinamento"""
408
+ try:
409
+ if self.training_process and self.training_process.poll() is None:
410
+ self.training_process.terminate()
411
+ self.training_process.wait(timeout=10)
412
+ return "⏹️ Treinamento interrompido"
413
+ else:
414
+ return "ℹ️ Nenhum treinamento ativo para parar"
415
+ except Exception as e:
416
+ return f"❌ Erro ao parar treinamento: {e}"
417
+
418
+ def list_output_files(self, project_name: str) -> List[str]:
419
+ """Lista arquivos de saída"""
420
+ try:
421
+ if not project_name.strip():
422
+ return []
423
+
424
+ project_name = project_name.strip().replace(" ", "_")
425
+ project_dir = self.projects_dir / project_name
426
+ output_dir = project_dir / "output"
427
+
428
+ if not output_dir.exists():
429
+ return []
430
+
431
+ files = []
432
+ for file_path in output_dir.rglob("*.safetensors"):
433
+ size_mb = file_path.stat().st_size // (1024 * 1024)
434
+ files.append(f"{file_path.name} ({size_mb} MB)")
435
+
436
+ return sorted(files, reverse=True) # Mais recentes primeiro
437
+
438
+ except Exception as e:
439
+ logger.error(f"Erro ao listar arquivos: {e}")
440
+ return []
441
+
442
+ # Instância global
443
+ trainer = LoRATrainerHF()
444
+
445
+ def create_interface():
446
+ """Cria a interface Gradio"""
447
+
448
+ with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface:
449
+
450
+ gr.Markdown("""
451
+ # 🎨 LoRA Trainer Funcional para Hugging Face
452
+
453
+ **Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!**
454
+
455
+ Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA.
456
+ """)
457
+
458
+ # Estado para armazenar informações
459
+ dataset_dir_state = gr.State("")
460
+
461
+ with gr.Tab("🔧 Instalação"):
462
+ gr.Markdown("### Primeiro, instale as dependências necessárias:")
463
+ install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg")
464
+ install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False)
465
+
466
+ install_btn.click(
467
+ fn=trainer.install_dependencies,
468
+ outputs=install_status
469
+ )
470
+
471
+ with gr.Tab("📁 Configuração do Projeto"):
472
+ with gr.Row():
473
+ project_name = gr.Textbox(
474
+ label="Nome do Projeto",
475
+ placeholder="meu_lora_anime",
476
+ info="Nome único para seu projeto (sem espaços especiais)"
477
+ )
478
+
479
+ gr.Markdown("### 📥 Download do Modelo Base")
480
+ with gr.Row():
481
+ model_choice = gr.Dropdown(
482
+ choices=list(trainer.model_urls.keys()),
483
+ label="Modelo Base Pré-definido",
484
+ value="Anime (animefull-final-pruned)",
485
+ info="Escolha um modelo base ou use URL personalizada"
486
+ )
487
+ custom_model_url = gr.Textbox(
488
+ label="URL Personalizada (opcional)",
489
+ placeholder="https://huggingface.co/...",
490
+ info="URL direta para download de modelo personalizado"
491
+ )
492
+
493
+ download_btn = gr.Button("📥 Baixar Modelo", variant="primary")
494
+ download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False)
495
+
496
+ gr.Markdown("### 📊 Upload do Dataset")
497
+ gr.Markdown("""
498
+ **Formato do Dataset:**
499
+ - Crie um arquivo ZIP contendo suas imagens
500
+ - Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições
501
+ - Exemplo: `imagem1.jpg` + `imagem1.txt`
502
+ """)
503
+
504
+ dataset_upload = gr.File(
505
+ label="Upload do Dataset (ZIP)",
506
+ file_types=[".zip"]
507
+ )
508
+
509
+ process_btn = gr.Button("📊 Processar Dataset", variant="primary")
510
+ dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False)
511
+
512
+ with gr.Tab("⚙️ Parâmetros de Treinamento"):
513
+ with gr.Row():
514
+ with gr.Column():
515
+ gr.Markdown("#### 🖼️ Configurações de Imagem")
516
+ resolution = gr.Slider(
517
+ minimum=512, maximum=1024, step=64, value=512,
518
+ label="Resolução",
519
+ info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)"
520
+ )
521
+ batch_size = gr.Slider(
522
+ minimum=1, maximum=8, step=1, value=1,
523
+ label="Batch Size",
524
+ info="Imagens por lote (aumente se tiver GPU potente)"
525
+ )
526
+ flip_aug = gr.Checkbox(
527
+ label="Flip Augmentation",
528
+ info="Espelhar imagens para aumentar dataset"
529
+ )
530
+ shuffle_caption = gr.Checkbox(
531
+ value=True,
532
+ label="Shuffle Caption",
533
+ info="Embaralhar ordem das tags"
534
+ )
535
+ keep_tokens = gr.Slider(
536
+ minimum=0, maximum=5, step=1, value=1,
537
+ label="Keep Tokens",
538
+ info="Número de tokens iniciais que não serão embaralhados"
539
+ )
540
+
541
+ with gr.Column():
542
+ gr.Markdown("#### 🎯 Configurações de Treinamento")
543
+ epochs = gr.Slider(
544
+ minimum=1, maximum=100, step=1, value=10,
545
+ label="Épocas",
546
+ info="Número de épocas de treinamento"
547
+ )
548
+ max_train_steps = gr.Number(
549
+ value=0,
550
+ label="Max Train Steps (0 = usar épocas)",
551
+ info="Número máximo de steps (deixe 0 para usar épocas)"
552
+ )
553
+ save_every_n_epochs = gr.Slider(
554
+ minimum=1, maximum=10, step=1, value=1,
555
+ label="Salvar a cada N épocas",
556
+ info="Frequência de salvamento dos checkpoints"
557
+ )
558
+ mixed_precision = gr.Dropdown(
559
+ choices=["fp16", "bf16", "no"],
560
+ value="fp16",
561
+ label="Mixed Precision",
562
+ info="fp16 = mais rápido, bf16 = mais estável"
563
+ )
564
+ clip_skip = gr.Slider(
565
+ minimum=1, maximum=12, step=1, value=2,
566
+ label="CLIP Skip",
567
+ info="Camadas CLIP a pular (2 para anime, 1 para realista)"
568
+ )
569
+
570
+ with gr.Row():
571
+ with gr.Column():
572
+ gr.Markdown("#### 📚 Learning Rate")
573
+ learning_rate = gr.Number(
574
+ value=1e-4,
575
+ label="Learning Rate (UNet)",
576
+ info="Taxa de aprendizado principal"
577
+ )
578
+ text_encoder_lr = gr.Number(
579
+ value=5e-5,
580
+ label="Learning Rate (Text Encoder)",
581
+ info="Taxa de aprendizado do text encoder"
582
+ )
583
+ scheduler = gr.Dropdown(
584
+ choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
585
+ value="cosine_with_restarts",
586
+ label="LR Scheduler",
587
+ info="Algoritmo de ajuste da learning rate"
588
+ )
589
+ optimizer = gr.Dropdown(
590
+ choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
591
+ value="AdamW8bit",
592
+ label="Otimizador",
593
+ info="AdamW8bit = menos memória"
594
+ )
595
+
596
+ with gr.Column():
597
+ gr.Markdown("#### 🧠 Arquitetura LoRA")
598
+ lora_type = gr.Radio(
599
+ choices=["LoRA", "LoCon"],
600
+ value="LoRA",
601
+ label="Tipo de LoRA",
602
+ info="LoRA = geral, LoCon = estilos artísticos"
603
+ )
604
+ network_dim = gr.Slider(
605
+ minimum=4, maximum=128, step=4, value=32,
606
+ label="Network Dimension",
607
+ info="Dimensão da rede (maior = mais detalhes, mais memória)"
608
+ )
609
+ network_alpha = gr.Slider(
610
+ minimum=1, maximum=128, step=1, value=16,
611
+ label="Network Alpha",
612
+ info="Controla a força do LoRA (geralmente dim/2)"
613
+ )
614
+
615
+ with gr.Tab("🚀 Treinamento"):
616
+ create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg")
617
+ config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False)
618
+
619
+ with gr.Row():
620
+ start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg")
621
+ stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop")
622
+
623
+ training_output = gr.Textbox(
624
+ label="Output do Treinamento",
625
+ lines=15,
626
+ interactive=False,
627
+ info="Acompanhe o progresso do treinamento em tempo real"
628
+ )
629
+
630
+ # Auto-refresh do output
631
+ def update_output():
632
+ return trainer.get_training_output()
633
+
634
+ with gr.Tab("📥 Download dos Resultados"):
635
+ refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary")
636
+
637
+ output_files = gr.Dropdown(
638
+ label="Arquivos LoRA Gerados",
639
+ choices=[],
640
+ info="Selecione um arquivo para download"
641
+ )
642
+
643
+ download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento")
644
+
645
+ # Event handlers
646
+ download_btn.click(
647
+ fn=trainer.download_model,
648
+ inputs=[model_choice, custom_model_url],
649
+ outputs=download_status
650
+ )
651
+
652
+ process_btn.click(
653
+ fn=trainer.process_dataset,
654
+ inputs=[dataset_upload, project_name],
655
+ outputs=[dataset_status, dataset_dir_state]
656
+ )
657
+
658
+ create_config_btn.click(
659
+ fn=trainer.create_training_config,
660
+ inputs=[
661
+ project_name, dataset_dir_state, model_choice, custom_model_url,
662
+ resolution, batch_size, epochs, learning_rate, text_encoder_lr,
663
+ network_dim, network_alpha, lora_type, optimizer, scheduler,
664
+ flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
665
+ save_every_n_epochs, max_train_steps
666
+ ],
667
+ outputs=config_status
668
+ )
669
+
670
+ start_training_btn.click(
671
+ fn=trainer.start_training,
672
+ inputs=project_name,
673
+ outputs=training_output
674
+ )
675
+
676
+ stop_training_btn.click(
677
+ fn=trainer.stop_training,
678
+ outputs=training_output
679
+ )
680
+
681
+ refresh_files_btn.click(
682
+ fn=trainer.list_output_files,
683
+ inputs=project_name,
684
+ outputs=output_files
685
+ )
686
+
687
+ return interface
688
+
689
+ if __name__ == "__main__":
690
+ print("🚀 Iniciando LoRA Trainer Funcional...")
691
+ interface = create_interface()
692
+ interface.launch(
693
+ server_name="0.0.0.0",
694
+ server_port=7860,
695
+ share=False,
696
+ show_error=True
697
+ )
698
+
fine_tune.py ADDED
@@ -0,0 +1,538 @@
1
+ # training with captions
2
+ # XXX dropped option: hypernetwork training
3
+
4
+ import argparse
5
+ import math
6
+ import os
7
+ from multiprocessing import Value
8
+ import toml
9
+
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from library import deepspeed_utils
14
+ from library.device_utils import init_ipex, clean_memory_on_device
15
+
16
+ init_ipex()
17
+
18
+ from accelerate.utils import set_seed
19
+ from diffusers import DDPMScheduler
20
+
21
+ from library.utils import setup_logging, add_logging_arguments
22
+
23
+ setup_logging()
24
+ import logging
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ import library.train_util as train_util
29
+ import library.config_util as config_util
30
+ from library.config_util import (
31
+ ConfigSanitizer,
32
+ BlueprintGenerator,
33
+ )
34
+ import library.custom_train_functions as custom_train_functions
35
+ from library.custom_train_functions import (
36
+ apply_snr_weight,
37
+ get_weighted_text_embeddings,
38
+ prepare_scheduler_for_custom_training,
39
+ scale_v_prediction_loss_like_noise_prediction,
40
+ apply_debiased_estimation,
41
+ )
42
+
43
+
44
+ def train(args):
45
+ train_util.verify_training_args(args)
46
+ train_util.prepare_dataset_args(args, True)
47
+ deepspeed_utils.prepare_deepspeed_args(args)
48
+ setup_logging(args, reset=True)
49
+
50
+ cache_latents = args.cache_latents
51
+
52
+ if args.seed is not None:
53
+ set_seed(args.seed) # 乱数系列を初期化する
54
+
55
+ tokenizer = train_util.load_tokenizer(args)
56
+
57
+ # データセットを準備する
58
+ if args.dataset_class is None:
59
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
60
+ if args.dataset_config is not None:
61
+ logger.info(f"Load dataset config from {args.dataset_config}")
62
+ user_config = config_util.load_user_config(args.dataset_config)
63
+ ignored = ["train_data_dir", "in_json"]
64
+ if any(getattr(args, attr) is not None for attr in ignored):
65
+ logger.warning(
66
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
67
+ ", ".join(ignored)
68
+ )
69
+ )
70
+ else:
71
+ user_config = {
72
+ "datasets": [
73
+ {
74
+ "subsets": [
75
+ {
76
+ "image_dir": args.train_data_dir,
77
+ "metadata_file": args.in_json,
78
+ }
79
+ ]
80
+ }
81
+ ]
82
+ }
83
+
84
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
85
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
86
+ else:
87
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
88
+
89
+ current_epoch = Value("i", 0)
90
+ current_step = Value("i", 0)
91
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
92
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
93
+
94
+ train_dataset_group.verify_bucket_reso_steps(64)
95
+
96
+ if args.debug_dataset:
97
+ train_util.debug_dataset(train_dataset_group)
98
+ return
99
+ if len(train_dataset_group) == 0:
100
+ logger.error(
101
+ "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
102
+ )
103
+ return
104
+
105
+ if cache_latents:
106
+ assert (
107
+ train_dataset_group.is_latent_cacheable()
108
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
109
+
110
+ # acceleratorを準備する
111
+ logger.info("prepare accelerator")
112
+ accelerator = train_util.prepare_accelerator(args)
113
+
114
+ # mixed precisionに対応した型を用意しておき適宜castする
115
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
116
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
117
+
118
+ # モデルを読み込む
119
+ text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
120
+
121
+ # verify load/save model formats
122
+ if load_stable_diffusion_format:
123
+ src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
124
+ src_diffusers_model_path = None
125
+ else:
126
+ src_stable_diffusion_ckpt = None
127
+ src_diffusers_model_path = args.pretrained_model_name_or_path
128
+
129
+ if args.save_model_as is None:
130
+ save_stable_diffusion_format = load_stable_diffusion_format
131
+ use_safetensors = args.use_safetensors
132
+ else:
133
+ save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
134
+ use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
135
+
136
+ # Diffusers版のxformers使用フラグを設定する関数
137
+ def set_diffusers_xformers_flag(model, valid):
138
+ # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう
139
+ # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`)
140
+ # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか
141
+ # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^)
142
+
143
+ # Recursively walk through all the children.
144
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
145
+ # gets the message
146
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
147
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
148
+ module.set_use_memory_efficient_attention_xformers(valid)
149
+
150
+ for child in module.children():
151
+ fn_recursive_set_mem_eff(child)
152
+
153
+ fn_recursive_set_mem_eff(model)
154
+
155
+ # モデルに xformers とか memory efficient attention を組み込む
156
+ if args.diffusers_xformers:
157
+ accelerator.print("Use xformers by Diffusers")
158
+ set_diffusers_xformers_flag(unet, True)
159
+ else:
160
+ # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
161
+ accelerator.print("Disable Diffusers' xformers")
162
+ set_diffusers_xformers_flag(unet, False)
163
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
164
+
165
+ # 学習を準備する
166
+ if cache_latents:
167
+ vae.to(accelerator.device, dtype=vae_dtype)
168
+ vae.requires_grad_(False)
169
+ vae.eval()
170
+ with torch.no_grad():
171
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
172
+ vae.to("cpu")
173
+ clean_memory_on_device(accelerator.device)
174
+
175
+ accelerator.wait_for_everyone()
176
+
177
+ # 学習を準備する:モデルを適切な状態にする
178
+ training_models = []
179
+ if args.gradient_checkpointing:
180
+ unet.enable_gradient_checkpointing()
181
+ training_models.append(unet)
182
+
183
+ if args.train_text_encoder:
184
+ accelerator.print("enable text encoder training")
185
+ if args.gradient_checkpointing:
186
+ text_encoder.gradient_checkpointing_enable()
187
+ training_models.append(text_encoder)
188
+ else:
189
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
190
+ text_encoder.requires_grad_(False) # text encoderは学習しない
191
+ if args.gradient_checkpointing:
192
+ text_encoder.gradient_checkpointing_enable()
193
+ text_encoder.train() # required for gradient_checkpointing
194
+ else:
195
+ text_encoder.eval()
196
+
197
+ if not cache_latents:
198
+ vae.requires_grad_(False)
199
+ vae.eval()
200
+ vae.to(accelerator.device, dtype=vae_dtype)
201
+
202
+ for m in training_models:
203
+ m.requires_grad_(True)
204
+
205
+ trainable_params = []
206
+ if args.learning_rate_te is None or not args.train_text_encoder:
207
+ for m in training_models:
208
+ trainable_params.extend(m.parameters())
209
+ else:
210
+ trainable_params = [
211
+ {"params": list(unet.parameters()), "lr": args.learning_rate},
212
+ {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
213
+ ]
214
+
215
+ # 学習に必要なクラスを準備する
216
+ accelerator.print("prepare optimizer, data loader etc.")
217
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
218
+
219
+ # dataloaderを準備する
220
+ # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
221
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
222
+ train_dataloader = torch.utils.data.DataLoader(
223
+ train_dataset_group,
224
+ batch_size=1,
225
+ shuffle=True,
226
+ collate_fn=collator,
227
+ num_workers=n_workers,
228
+ persistent_workers=args.persistent_data_loader_workers,
229
+ )
230
+
231
+ # 学習ステップ数を計算する
232
+ if args.max_train_epochs is not None:
233
+ args.max_train_steps = args.max_train_epochs * math.ceil(
234
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
235
+ )
236
+ accelerator.print(
237
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
238
+ )
239
+
240
+ # データセット側にも学習ステップを送信
241
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
242
+
243
+ # lr schedulerを用意する
244
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
245
+
246
+ # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
247
+ if args.full_fp16:
248
+ assert (
249
+ args.mixed_precision == "fp16"
250
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
251
+ accelerator.print("enable full fp16 training.")
252
+ unet.to(weight_dtype)
253
+ text_encoder.to(weight_dtype)
254
+
255
+ if args.deepspeed:
256
+ if args.train_text_encoder:
257
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
258
+ else:
259
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
260
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
261
+ ds_model, optimizer, train_dataloader, lr_scheduler
262
+ )
263
+ training_models = [ds_model]
264
+ else:
265
+ # acceleratorがなんかよろしくやってくれるらしい
266
+ if args.train_text_encoder:
267
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
268
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
269
+ )
270
+ else:
271
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
272
+
273
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
274
+ if args.full_fp16:
275
+ train_util.patch_accelerator_for_fp16_training(accelerator)
276
+
277
+ # resumeする
278
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
279
+
280
+ # epoch数を計算する
281
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
282
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
283
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
284
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
285
+
286
+ # 学習する
287
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
288
+ accelerator.print("running training / 学習開始")
289
+ accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
290
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
291
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
292
+ accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
293
+ accelerator.print(
294
+ f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
295
+ )
296
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
297
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
298
+
299
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
300
+ global_step = 0
301
+
302
+ noise_scheduler = DDPMScheduler(
303
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
304
+ )
305
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
306
+ if args.zero_terminal_snr:
307
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
308
+
309
+ if accelerator.is_main_process:
310
+ init_kwargs = {}
311
+ if args.wandb_run_name:
312
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
313
+ if args.log_tracker_config is not None:
314
+ init_kwargs = toml.load(args.log_tracker_config)
315
+ accelerator.init_trackers(
316
+ "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
317
+ config=train_util.get_sanitized_config_or_none(args),
318
+ init_kwargs=init_kwargs,
319
+ )
320
+
321
+ # For --sample_at_first
322
+ train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
323
+
324
+ loss_recorder = train_util.LossRecorder()
325
+ for epoch in range(num_train_epochs):
326
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
327
+ current_epoch.value = epoch + 1
328
+
329
+ for m in training_models:
330
+ m.train()
331
+
332
+ for step, batch in enumerate(train_dataloader):
333
+ current_step.value = global_step
334
+ with accelerator.accumulate(*training_models):
335
+ with torch.no_grad():
336
+ if "latents" in batch and batch["latents"] is not None:
337
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
338
+ else:
339
+ # latentに変換
340
+ latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
341
+ latents = latents * 0.18215
342
+ b_size = latents.shape[0]
343
+
344
+ with torch.set_grad_enabled(args.train_text_encoder):
345
+ # Get the text embedding for conditioning
346
+ if args.weighted_captions:
347
+ encoder_hidden_states = get_weighted_text_embeddings(
348
+ tokenizer,
349
+ text_encoder,
350
+ batch["captions"],
351
+ accelerator.device,
352
+ args.max_token_length // 75 if args.max_token_length else 1,
353
+ clip_skip=args.clip_skip,
354
+ )
355
+ else:
356
+ input_ids = batch["input_ids"].to(accelerator.device)
357
+ encoder_hidden_states = train_util.get_hidden_states(
358
+ args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
359
+ )
360
+
361
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
362
+ # with noise offset and/or multires noise if specified
363
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
364
+ args, noise_scheduler, latents
365
+ )
366
+
367
+ # Predict the noise residual
368
+ with accelerator.autocast():
369
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
370
+
371
+ if args.v_parameterization:
372
+ # v-parameterization training
373
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
374
+ else:
375
+ target = noise
376
+
377
+ if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
378
+ # do not mean over batch dimension for snr weight or scale v-pred loss
379
+ loss = train_util.conditional_loss(
380
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
381
+ )
382
+ loss = loss.mean([1, 2, 3])
383
+
384
+ if args.min_snr_gamma:
385
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
386
+ if args.scale_v_pred_loss_like_noise_pred:
387
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
388
+ if args.debiased_estimation_loss:
389
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
390
+
391
+ loss = loss.mean() # mean over batch dimension
392
+ else:
393
+ loss = train_util.conditional_loss(
394
+ noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
395
+ )
396
+
397
+ accelerator.backward(loss)
398
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
399
+ params_to_clip = []
400
+ for m in training_models:
401
+ params_to_clip.extend(m.parameters())
402
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
403
+
404
+ optimizer.step()
405
+ lr_scheduler.step()
406
+ optimizer.zero_grad(set_to_none=True)
407
+
408
+ # Checks if the accelerator has performed an optimization step behind the scenes
409
+ if accelerator.sync_gradients:
410
+ progress_bar.update(1)
411
+ global_step += 1
412
+
413
+ train_util.sample_images(
414
+ accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
415
+ )
416
+
417
+ # 指定ステップごとにモデルを保存
418
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
419
+ accelerator.wait_for_everyone()
420
+ if accelerator.is_main_process:
421
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
422
+ train_util.save_sd_model_on_epoch_end_or_stepwise(
423
+ args,
424
+ False,
425
+ accelerator,
426
+ src_path,
427
+ save_stable_diffusion_format,
428
+ use_safetensors,
429
+ save_dtype,
430
+ epoch,
431
+ num_train_epochs,
432
+ global_step,
433
+ accelerator.unwrap_model(text_encoder),
434
+ accelerator.unwrap_model(unet),
435
+ vae,
436
+ )
437
+
438
+ current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
439
+ if args.logging_dir is not None:
440
+ logs = {"loss": current_loss}
441
+ train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
442
+ accelerator.log(logs, step=global_step)
443
+
444
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
445
+ avr_loss: float = loss_recorder.moving_average
446
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
447
+ progress_bar.set_postfix(**logs)
448
+
449
+ if global_step >= args.max_train_steps:
450
+ break
451
+
452
+ if args.logging_dir is not None:
453
+ logs = {"loss/epoch": loss_recorder.moving_average}
454
+ accelerator.log(logs, step=epoch + 1)
455
+
456
+ accelerator.wait_for_everyone()
457
+
458
+ if args.save_every_n_epochs is not None:
459
+ if accelerator.is_main_process:
460
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
461
+ train_util.save_sd_model_on_epoch_end_or_stepwise(
462
+ args,
463
+ True,
464
+ accelerator,
465
+ src_path,
466
+ save_stable_diffusion_format,
467
+ use_safetensors,
468
+ save_dtype,
469
+ epoch,
470
+ num_train_epochs,
471
+ global_step,
472
+ accelerator.unwrap_model(text_encoder),
473
+ accelerator.unwrap_model(unet),
474
+ vae,
475
+ )
476
+
477
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
478
+
479
+ is_main_process = accelerator.is_main_process
480
+ if is_main_process:
481
+ unet = accelerator.unwrap_model(unet)
482
+ text_encoder = accelerator.unwrap_model(text_encoder)
483
+
484
+ accelerator.end_training()
485
+
486
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
487
+ train_util.save_state_on_train_end(args, accelerator)
488
+
489
+ del accelerator # この後メモリを使うのでこれは消す
490
+
491
+ if is_main_process:
492
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
493
+ train_util.save_sd_model_on_train_end(
494
+ args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
495
+ )
496
+ logger.info("model saved.")
497
+
498
+
499
+ def setup_parser() -> argparse.ArgumentParser:
500
+ parser = argparse.ArgumentParser()
501
+
502
+ add_logging_arguments(parser)
503
+ train_util.add_sd_models_arguments(parser)
504
+ train_util.add_dataset_arguments(parser, False, True, True)
505
+ train_util.add_training_arguments(parser, False)
506
+ deepspeed_utils.add_deepspeed_arguments(parser)
507
+ train_util.add_sd_saving_arguments(parser)
508
+ train_util.add_optimizer_arguments(parser)
509
+ config_util.add_config_arguments(parser)
510
+ custom_train_functions.add_custom_train_arguments(parser)
511
+
512
+ parser.add_argument(
513
+ "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
514
+ )
515
+ parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
516
+ parser.add_argument(
517
+ "--learning_rate_te",
518
+ type=float,
519
+ default=None,
520
+ help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
521
+ )
522
+ parser.add_argument(
523
+ "--no_half_vae",
524
+ action="store_true",
525
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
526
+ )
527
+
528
+ return parser
529
+
530
+
531
+ if __name__ == "__main__":
532
+ parser = setup_parser()
533
+
534
+ args = parser.parse_args()
535
+ train_util.verify_command_line_training_args(args)
536
+ args = train_util.read_config_from_file(args, parser)
537
+
538
+ train(args)
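fine_tune.py reduces the loss per sample (mean over the channel and spatial dimensions) before applying min-SNR weighting, v-prediction loss scaling, or debiased estimation, and only then averages over the batch. The sketch below illustrates just the min-SNR path for epsilon prediction under the same scaled_linear beta schedule; it recomputes alphas_cumprod locally and is not the library.custom_train_functions implementation.

import torch
import torch.nn.functional as F

def min_snr_weights(timesteps, alphas_cumprod, gamma):
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t); for epsilon prediction the
    # min-SNR weight is min(SNR, gamma) / SNR
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    return torch.clamp(snr, max=gamma) / snr

# toy inputs standing in for (B, C, H, W) latent-space predictions
noise_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)
timesteps = torch.randint(0, 1000, (4,))

# scaled_linear schedule with the same beta_start/beta_end as the DDPMScheduler above
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

loss = F.mse_loss(noise_pred, target, reduction="none")
loss = loss.mean([1, 2, 3])                                        # per-sample loss, as in the script
loss = loss * min_snr_weights(timesteps, alphas_cumprod, gamma=5.0)  # min-SNR weighting
loss = loss.mean()                                                  # batch mean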
gen_img.py ADDED
The diff for this file is too large to render. See raw diff
 
gen_img_diffusers.py ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,42 @@
1
+ accelerate==0.30.0
2
+ transformers==4.44.0
3
+ diffusers[torch]==0.25.0
4
+ ftfy==6.1.1
5
+ # albumentations==1.3.0
6
+ opencv-python==4.8.1.78
7
+ einops==0.7.0
8
+ pytorch-lightning==1.9.0
9
+ bitsandbytes==0.44.0
10
+ prodigyopt==1.0
11
+ lion-pytorch==0.0.6
12
+ tensorboard
13
+ safetensors==0.4.2
14
+ # gradio==3.16.2
15
+ altair==4.2.2
16
+ easygui==0.98.3
17
+ toml==0.10.2
18
+ voluptuous==0.13.1
19
+ huggingface-hub==0.24.5
20
+ # for Image utils
21
+ imagesize==1.4.1
22
+ # for BLIP captioning
23
+ # requests==2.28.2
24
+ # timm==0.6.12
25
+ # fairscale==0.4.13
26
+ # for WD14 captioning (tensorflow)
27
+ # tensorflow==2.10.1
28
+ # for WD14 captioning (onnx)
29
+ # onnx==1.15.0
30
+ # onnxruntime-gpu==1.17.1
31
+ # onnxruntime==1.17.1
32
+ # for cuda 12.1(default 11.8)
33
+ # onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
34
+
35
+ # this is for onnx:
36
+ # protobuf==3.20.3
37
+ # open clip for SDXL
38
+ # open-clip-torch==2.20.0
39
+ # For logging
40
+ rich==13.7.0
41
+ # for kohya_ss library
42
+ -e .
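The pins above are exact versions, so a mismatched environment is a common source of training failures. A small optional check using only the standard library, with an illustrative subset of the pinned packages, could look like this.

from importlib.metadata import version, PackageNotFoundError

pins = {"accelerate": "0.30.0", "transformers": "4.44.0", "diffusers": "0.25.0", "safetensors": "0.4.2"}
for name, expected in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "ok" if installed == expected else f"mismatch (expected {expected})"
    print(f"{name}: {installed} {status}")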
sdxl_gen_img.py ADDED
The diff for this file is too large to render. See raw diff
 
sdxl_minimal_inference.py ADDED
@@ -0,0 +1,345 @@
1
+ # 手元で推論を行うための最低限のコード。HuggingFace/DiffusersのCLIP、schedulerとVAEを使う
2
+ # Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE
3
+
4
+ import argparse
5
+ import datetime
6
+ import math
7
+ import os
8
+ import random
9
+ from einops import repeat
10
+ import numpy as np
11
+
12
+ import torch
13
+ from library.device_utils import init_ipex, get_preferred_device
14
+
15
+ init_ipex()
16
+
17
+ from tqdm import tqdm
18
+ from transformers import CLIPTokenizer
19
+ from diffusers import EulerDiscreteScheduler
20
+ from PIL import Image
21
+
22
+ # import open_clip
23
+ from safetensors.torch import load_file
24
+
25
+ from library import model_util, sdxl_model_util
26
+ import networks.lora as lora
27
+ from library.utils import setup_logging
28
+
29
+ setup_logging()
30
+ import logging
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # scheduler: このあたりの設定はSD1/2と同じでいいらしい
35
+ # scheduler: The settings around here seem to be the same as SD1/2
36
+ SCHEDULER_LINEAR_START = 0.00085
37
+ SCHEDULER_LINEAR_END = 0.0120
38
+ SCHEDULER_TIMESTEPS = 1000
39
+ SCHEDLER_SCHEDULE = "scaled_linear"
40
+
41
+
42
+ # Time EmbeddingはDiffusersからのコピー
43
+ # Time Embedding is copied from Diffusers
44
+
45
+
46
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
47
+ """
48
+ Create sinusoidal timestep embeddings.
49
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
50
+ These may be fractional.
51
+ :param dim: the dimension of the output.
52
+ :param max_period: controls the minimum frequency of the embeddings.
53
+ :return: an [N x dim] Tensor of positional embeddings.
54
+ """
55
+ if not repeat_only:
56
+ half = dim // 2
57
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
58
+ device=timesteps.device
59
+ )
60
+ args = timesteps[:, None].float() * freqs[None]
61
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
62
+ if dim % 2:
63
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
64
+ else:
65
+ embedding = repeat(timesteps, "b -> b d", d=dim)
66
+ return embedding
67
+
68
+
69
+ def get_timestep_embedding(x, outdim):
70
+ assert len(x.shape) == 2
71
+ b, dims = x.shape[0], x.shape[1]
72
+ # x = rearrange(x, "b d -> (b d)")
73
+ x = torch.flatten(x)
74
+ emb = timestep_embedding(x, outdim)
75
+ # emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
76
+ emb = torch.reshape(emb, (b, dims * outdim))
77
+ return emb
78
+
79
+
80
+ if __name__ == "__main__":
81
+ # 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
82
+
83
+ # SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
84
+ target_height = 1024
85
+ target_width = 1024
86
+ original_height = target_height
87
+ original_width = target_width
88
+ crop_top = 0
89
+ crop_left = 0
90
+
91
+ steps = 50
92
+ guidance_scale = 7
93
+ seed = None # 1
94
+
95
+ DEVICE = get_preferred_device()
96
+ DTYPE = torch.float16 # bfloat16 may work
97
+
98
+ parser = argparse.ArgumentParser()
99
+ parser.add_argument("--ckpt_path", type=str, required=True)
100
+ parser.add_argument("--prompt", type=str, default="A photo of a cat")
101
+ parser.add_argument("--prompt2", type=str, default=None)
102
+ parser.add_argument("--negative_prompt", type=str, default="")
103
+ parser.add_argument("--output_dir", type=str, default=".")
104
+ parser.add_argument(
105
+ "--lora_weights",
106
+ type=str,
107
+ nargs="*",
108
+ default=[],
109
+ help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
110
+ )
111
+ parser.add_argument("--interactive", action="store_true")
112
+ args = parser.parse_args()
113
+
114
+ if args.prompt2 is None:
115
+ args.prompt2 = args.prompt
116
+
117
+ # HuggingFaceのmodel id
118
+ text_encoder_1_name = "openai/clip-vit-large-patch14"
119
+ text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
120
+
121
+ # checkpointを読み込む。モデル変換についてはそちらの関数を参照
122
+ # Load checkpoint. For model conversion, see this function
123
+
124
+ # 本体RAMが少ない場合はGPUにロードするといいかも
125
+ # If the main RAM is small, it may be better to load it on the GPU
126
+ text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
127
+ sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
128
+ )
129
+
130
+ # Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
131
+ # In SDXL, Text Encoder 1 is also using HuggingFace's
132
+
133
+ # Text Encoder 2はSDXL本体ではopen_clipを使っている
134
+ # それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う
135
+ # 重みの変換コードはSD2とほぼ同じ
136
+ # In SDXL, Text Encoder 2 is using open_clip
137
+ # It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's
138
+ # The weight conversion code is almost the same as SD2
139
+
140
+ # VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う
141
+ # fp16でNaNが出やすいようだ
142
+ # The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different.
143
+ # NaN seems to be more likely to occur in fp16
144
+
145
+ unet.to(DEVICE, dtype=DTYPE)
146
+ unet.eval()
147
+
148
+ vae_dtype = DTYPE
149
+ if DTYPE == torch.float16:
150
+ logger.info("use float32 for vae")
151
+ vae_dtype = torch.float32
152
+ vae.to(DEVICE, dtype=vae_dtype)
153
+ vae.eval()
154
+
155
+ text_model1.to(DEVICE, dtype=DTYPE)
156
+ text_model1.eval()
157
+ text_model2.to(DEVICE, dtype=DTYPE)
158
+ text_model2.eval()
159
+
160
+ unet.set_use_memory_efficient_attention(True, False)
161
+ if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
162
+ vae.set_use_memory_efficient_attention_xformers(True)
163
+
164
+ # Tokenizers
165
+ tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
166
+ # tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
167
+ tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
168
+
169
+ # LoRA
170
+ for weights_file in args.lora_weights:
171
+ if ";" in weights_file:
172
+ weights_file, multiplier = weights_file.split(";")
173
+ multiplier = float(multiplier)
174
+ else:
175
+ multiplier = 1.0
176
+
177
+ lora_model, weights_sd = lora.create_network_from_weights(
178
+ multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
179
+ )
180
+ lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
181
+
182
+ # scheduler
183
+ scheduler = EulerDiscreteScheduler(
184
+ num_train_timesteps=SCHEDULER_TIMESTEPS,
185
+ beta_start=SCHEDULER_LINEAR_START,
186
+ beta_end=SCHEDULER_LINEAR_END,
187
+ beta_schedule=SCHEDLER_SCHEDULE,
188
+ )
189
+
190
+ def generate_image(prompt, prompt2, negative_prompt, seed=None):
191
+ # 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
192
+ # prepare embedding
193
+ with torch.no_grad():
194
+ # vector
195
+ emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
196
+ emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
197
+ emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
198
+ # logger.info("emb1", emb1.shape)
199
+ c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
200
+ uc_vector = c_vector.clone().to(
201
+ DEVICE, dtype=DTYPE
202
+ ) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
203
+
204
+ # crossattn
205
+
206
+ # Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
207
+ def call_text_encoder(text, text2):
208
+ # text encoder 1
209
+ batch_encoding = tokenizer1(
210
+ text,
211
+ truncation=True,
212
+ return_length=True,
213
+ return_overflowing_tokens=False,
214
+ padding="max_length",
215
+ return_tensors="pt",
216
+ )
217
+ tokens = batch_encoding["input_ids"].to(DEVICE)
218
+
219
+ with torch.no_grad():
220
+ enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
221
+ text_embedding1 = enc_out["hidden_states"][11]
222
+ # text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
223
+
224
+ # text encoder 2
225
+ # tokens = tokenizer2(text2).to(DEVICE)
226
+ tokens = tokenizer2(
227
+ text2,
228
+ truncation=True,
229
+ return_length=True,
230
+ return_overflowing_tokens=False,
231
+ padding="max_length",
232
+ return_tensors="pt",
233
+ )
234
+ tokens = tokens["input_ids"].to(DEVICE)
235
+
236
+ with torch.no_grad():
237
+ enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
238
+ text_embedding2_penu = enc_out["hidden_states"][-2]
239
+ # logger.info("hidden_states2", text_embedding2_penu.shape)
240
+ text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
241
+
242
+ # 連結して終了 concat and finish
243
+ text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
244
+ return text_embedding, text_embedding2_pool
245
+
246
+ # cond
247
+ c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
248
+ # logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
249
+ c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
250
+
251
+ # uncond
252
+ uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
253
+ uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
254
+
255
+ text_embeddings = torch.cat([uc_ctx, c_ctx])
256
+ vector_embeddings = torch.cat([uc_vector, c_vector])
257
+
258
+ # メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
259
+
260
+ if seed is not None:
261
+ random.seed(seed)
262
+ np.random.seed(seed)
263
+ torch.manual_seed(seed)
264
+ torch.cuda.manual_seed_all(seed)
265
+
266
+ # # random generator for initial noise
267
+ # generator = torch.Generator(device="cuda").manual_seed(seed)
268
+ generator = None
269
+ else:
270
+ generator = None
271
+
272
+ # get the initial random noise unless the user supplied it
273
+ # SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している
274
+ # SDXL creates latents in CPU, Diffusers creates latents in target device
275
+ latents_shape = (1, 4, target_height // 8, target_width // 8)
276
+ latents = torch.randn(
277
+ latents_shape,
278
+ generator=generator,
279
+ device="cpu",
280
+ dtype=torch.float32,
281
+ ).to(DEVICE, dtype=DTYPE)
282
+
283
+ # scale the initial noise by the standard deviation required by the scheduler
284
+ latents = latents * scheduler.init_noise_sigma
285
+
286
+ # set timesteps
287
+ scheduler.set_timesteps(steps, DEVICE)
288
+
289
+ # このへんはDiffusersからのコピペ
290
+ # Copy from Diffusers
291
+ timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
292
+ num_latent_input = 2
293
+ with torch.no_grad():
294
+ for i, t in enumerate(tqdm(timesteps)):
295
+ # expand the latents if we are doing classifier free guidance
296
+ latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
297
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
298
+
299
+ noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
300
+
301
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
302
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
303
+
304
+ # compute the previous noisy sample x_t -> x_t-1
305
+ # latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
306
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
307
+
308
+ # latents = 1 / 0.18215 * latents
309
+ latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
310
+ latents = latents.to(vae_dtype)
311
+ image = vae.decode(latents).sample
312
+ image = (image / 2 + 0.5).clamp(0, 1)
313
+
314
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
315
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
316
+
317
+ # image = self.numpy_to_pil(image)
318
+ image = (image * 255).round().astype("uint8")
319
+ image = [Image.fromarray(im) for im in image]
320
+
321
+ # 保存して終了 save and finish
322
+ timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
323
+ for i, img in enumerate(image):
324
+ img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
325
+
326
+ if not args.interactive:
327
+ generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
328
+ else:
329
+ # loop for interactive
330
+ while True:
331
+ prompt = input("prompt: ")
332
+ if prompt == "":
333
+ break
334
+ prompt2 = input("prompt2: ")
335
+ if prompt2 == "":
336
+ prompt2 = prompt
337
+ negative_prompt = input("negative prompt: ")
338
+ seed = input("seed: ")
339
+ if seed == "":
340
+ seed = None
341
+ else:
342
+ seed = int(seed)
343
+ generate_image(prompt, prompt2, negative_prompt, seed)
344
+
345
+ logger.info("Done!")
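generate_image above builds SDXL's extra vector conditioning from three sinusoidal size/crop embeddings (2 values x 256 features each) concatenated with the 1280-dim pooled output of text encoder 2, giving the 2816-dim vector the U-Net expects. Below is a standalone shape check under those assumptions, with a zero tensor standing in for the pooled text embedding; it mirrors timestep_embedding/get_timestep_embedding from the script but is not the script's code.

import math
import torch

def sinusoidal_embedding(values, dim=256):
    # values: (B, 2), e.g. [[height, width]]; returns (B, 2 * dim)
    b, n = values.shape
    flat = values.flatten().float()
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
    args = flat[:, None] * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (B * n, dim)
    return emb.reshape(b, n * dim)

size_emb = sinusoidal_embedding(torch.tensor([[1024.0, 1024.0]]))    # original size
crop_emb = sinusoidal_embedding(torch.tensor([[0.0, 0.0]]))          # crop top/left
target_emb = sinusoidal_embedding(torch.tensor([[1024.0, 1024.0]]))  # target size
pooled = torch.zeros(1, 1280)  # placeholder for text_encoder2's pooled output

vector_embedding = torch.cat([pooled, size_emb, crop_emb, target_emb], dim=1)
print(vector_embedding.shape)  # torch.Size([1, 2816])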
sdxl_train.py ADDED
@@ -0,0 +1,952 @@
1
+ # training with captions
2
+
3
+ import argparse
4
+ import math
5
+ import os
6
+ from multiprocessing import Value
7
+ from typing import List
8
+ import toml
9
+
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from library.device_utils import init_ipex, clean_memory_on_device
14
+
15
+
16
+ init_ipex()
17
+
18
+ from accelerate.utils import set_seed
19
+ from diffusers import DDPMScheduler
20
+ from library import deepspeed_utils, sdxl_model_util
21
+
22
+ import library.train_util as train_util
23
+
24
+ from library.utils import setup_logging, add_logging_arguments
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ import library.config_util as config_util
32
+ import library.sdxl_train_util as sdxl_train_util
33
+ from library.config_util import (
34
+ ConfigSanitizer,
35
+ BlueprintGenerator,
36
+ )
37
+ import library.custom_train_functions as custom_train_functions
38
+ from library.custom_train_functions import (
39
+ apply_snr_weight,
40
+ prepare_scheduler_for_custom_training,
41
+ scale_v_prediction_loss_like_noise_prediction,
42
+ add_v_prediction_like_loss,
43
+ apply_debiased_estimation,
44
+ apply_masked_loss,
45
+ )
46
+ from library.sdxl_original_unet import SdxlUNet2DConditionModel
47
+
48
+
49
+ UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
50
+
51
+
52
+ def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
53
+ block_params = [[] for _ in range(len(block_lrs))]
54
+
55
+ for i, (name, param) in enumerate(unet.named_parameters()):
56
+ if name.startswith("time_embed.") or name.startswith("label_emb."):
57
+ block_index = 0 # 0
58
+ elif name.startswith("input_blocks."): # 1-9
59
+ block_index = 1 + int(name.split(".")[1])
60
+ elif name.startswith("middle_block."): # 10-12
61
+ block_index = 10 + int(name.split(".")[1])
62
+ elif name.startswith("output_blocks."): # 13-21
63
+ block_index = 13 + int(name.split(".")[1])
64
+ elif name.startswith("out."): # 22
65
+ block_index = 22
66
+ else:
67
+ raise ValueError(f"unexpected parameter name: {name}")
68
+
69
+ block_params[block_index].append(param)
70
+
71
+ params_to_optimize = []
72
+ for i, params in enumerate(block_params):
73
+ if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0
74
+ continue
75
+ params_to_optimize.append({"params": params, "lr": block_lrs[i]})
76
+
77
+ return params_to_optimize
78
+
79
+
80
+ def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
81
+ names = []
82
+ block_index = 0
83
+ while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
84
+ if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
85
+ if block_lrs[block_index] == 0:
86
+ block_index += 1
87
+ continue
88
+ names.append(f"block{block_index}")
89
+ elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
90
+ names.append("text_encoder1")
91
+ elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
92
+ names.append("text_encoder2")
93
+
94
+ block_index += 1
95
+
96
+ train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
97
+
98
+
99
+ def train(args):
100
+ train_util.verify_training_args(args)
101
+ train_util.prepare_dataset_args(args, True)
102
+ sdxl_train_util.verify_sdxl_training_args(args)
103
+ deepspeed_utils.prepare_deepspeed_args(args)
104
+ setup_logging(args, reset=True)
105
+
106
+ assert (
107
+ not args.weighted_captions
108
+ ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
109
+ assert (
110
+ not args.train_text_encoder or not args.cache_text_encoder_outputs
111
+ ), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
112
+
113
+ if args.block_lr:
114
+ block_lrs = [float(lr) for lr in args.block_lr.split(",")]
115
+ assert (
116
+ len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
117
+ ), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
118
+ else:
119
+ block_lrs = None
120
+
121
+ cache_latents = args.cache_latents
122
+ use_dreambooth_method = args.in_json is None
123
+
124
+ if args.seed is not None:
125
+ set_seed(args.seed) # 乱数系列を初期化する
126
+
127
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
128
+
129
+ # データセットを準備する
130
+ if args.dataset_class is None:
131
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
132
+ if args.dataset_config is not None:
133
+ logger.info(f"Load dataset config from {args.dataset_config}")
134
+ user_config = config_util.load_user_config(args.dataset_config)
135
+ ignored = ["train_data_dir", "in_json"]
136
+ if any(getattr(args, attr) is not None for attr in ignored):
137
+ logger.warning(
138
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
139
+ ", ".join(ignored)
140
+ )
141
+ )
142
+ else:
143
+ if use_dreambooth_method:
144
+ logger.info("Using DreamBooth method.")
145
+ user_config = {
146
+ "datasets": [
147
+ {
148
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
149
+ args.train_data_dir, args.reg_data_dir
150
+ )
151
+ }
152
+ ]
153
+ }
154
+ else:
155
+ logger.info("Training with captions.")
156
+ user_config = {
157
+ "datasets": [
158
+ {
159
+ "subsets": [
160
+ {
161
+ "image_dir": args.train_data_dir,
162
+ "metadata_file": args.in_json,
163
+ }
164
+ ]
165
+ }
166
+ ]
167
+ }
168
+
169
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
170
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
171
+ else:
172
+ train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
173
+
174
+ current_epoch = Value("i", 0)
175
+ current_step = Value("i", 0)
176
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
177
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
178
+
179
+ train_dataset_group.verify_bucket_reso_steps(32)
180
+
181
+ if args.debug_dataset:
182
+ train_util.debug_dataset(train_dataset_group, True)
183
+ return
184
+ if len(train_dataset_group) == 0:
185
+ logger.error(
186
+ "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
187
+ )
188
+ return
189
+
190
+ if cache_latents:
191
+ assert (
192
+ train_dataset_group.is_latent_cacheable()
193
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
194
+
195
+ if args.cache_text_encoder_outputs:
196
+ assert (
197
+ train_dataset_group.is_text_encoder_output_cacheable()
198
+ ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
199
+
200
+ # acceleratorを準備する
201
+ logger.info("prepare accelerator")
202
+ accelerator = train_util.prepare_accelerator(args)
203
+
204
+ # mixed precisionに対応した型を用意しておき適宜castする
205
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
206
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
207
+
208
+ # モデルを読み込む
209
+ (
210
+ load_stable_diffusion_format,
211
+ text_encoder1,
212
+ text_encoder2,
213
+ vae,
214
+ unet,
215
+ logit_scale,
216
+ ckpt_info,
217
+ ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
218
+ # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
219
+
220
+ # verify load/save model formats
221
+ if load_stable_diffusion_format:
222
+ src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
223
+ src_diffusers_model_path = None
224
+ else:
225
+ src_stable_diffusion_ckpt = None
226
+ src_diffusers_model_path = args.pretrained_model_name_or_path
227
+
228
+ if args.save_model_as is None:
229
+ save_stable_diffusion_format = load_stable_diffusion_format
230
+ use_safetensors = args.use_safetensors
231
+ else:
232
+ save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
233
+ use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
234
+ # assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
235
+
236
+ # Diffusers版のxformers使用フラグを設定する関数
237
+ def set_diffusers_xformers_flag(model, valid):
238
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
239
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
240
+ module.set_use_memory_efficient_attention_xformers(valid)
241
+
242
+ for child in module.children():
243
+ fn_recursive_set_mem_eff(child)
244
+
245
+ fn_recursive_set_mem_eff(model)
246
+
247
+ # モデルに xformers とか memory efficient attention を組み込む
248
+ if args.diffusers_xformers:
249
+ # もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず
250
+ accelerator.print("Use xformers by Diffusers")
251
+ # set_diffusers_xformers_flag(unet, True)
252
+ set_diffusers_xformers_flag(vae, True)
253
+ else:
254
+ # Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある
255
+ accelerator.print("Disable Diffusers' xformers")
256
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
257
+ if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
258
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
259
+
260
+ # 学習を準備する
261
+ if cache_latents:
262
+ vae.to(accelerator.device, dtype=vae_dtype)
263
+ vae.requires_grad_(False)
264
+ vae.eval()
265
+ with torch.no_grad():
266
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
267
+ vae.to("cpu")
268
+ clean_memory_on_device(accelerator.device)
269
+
270
+ accelerator.wait_for_everyone()
271
+
272
+ # 学習を準備する:モデルを適切な状態にする
273
+ if args.gradient_checkpointing:
274
+ unet.enable_gradient_checkpointing()
275
+ train_unet = args.learning_rate != 0
276
+ train_text_encoder1 = False
277
+ train_text_encoder2 = False
278
+
279
+ if args.train_text_encoder:
280
+ # TODO each option for two text encoders?
281
+ accelerator.print("enable text encoder training")
282
+ if args.gradient_checkpointing:
283
+ text_encoder1.gradient_checkpointing_enable()
284
+ text_encoder2.gradient_checkpointing_enable()
285
+ lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
286
+ lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
287
+ train_text_encoder1 = lr_te1 != 0
288
+ train_text_encoder2 = lr_te2 != 0
289
+
290
+ # caching one text encoder output is not supported
291
+ if not train_text_encoder1:
292
+ text_encoder1.to(weight_dtype)
293
+ if not train_text_encoder2:
294
+ text_encoder2.to(weight_dtype)
295
+ text_encoder1.requires_grad_(train_text_encoder1)
296
+ text_encoder2.requires_grad_(train_text_encoder2)
297
+ text_encoder1.train(train_text_encoder1)
298
+ text_encoder2.train(train_text_encoder2)
299
+ else:
300
+ text_encoder1.to(weight_dtype)
301
+ text_encoder2.to(weight_dtype)
302
+ text_encoder1.requires_grad_(False)
303
+ text_encoder2.requires_grad_(False)
304
+ text_encoder1.eval()
305
+ text_encoder2.eval()
306
+
307
+ # TextEncoderの出力をキャッシュする
308
+ if args.cache_text_encoder_outputs:
309
+ # Text Encodes are eval and no grad
310
+ with torch.no_grad(), accelerator.autocast():
311
+ train_dataset_group.cache_text_encoder_outputs(
312
+ (tokenizer1, tokenizer2),
313
+ (text_encoder1, text_encoder2),
314
+ accelerator.device,
315
+ None,
316
+ args.cache_text_encoder_outputs_to_disk,
317
+ accelerator.is_main_process,
318
+ )
319
+ accelerator.wait_for_everyone()
320
+
321
+ if not cache_latents:
322
+ vae.requires_grad_(False)
323
+ vae.eval()
324
+ vae.to(accelerator.device, dtype=vae_dtype)
325
+
326
+ unet.requires_grad_(train_unet)
327
+ if not train_unet:
328
+ unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
329
+
330
+ training_models = []
331
+ params_to_optimize = []
332
+ if train_unet:
333
+ training_models.append(unet)
334
+ if block_lrs is None:
335
+ params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
336
+ else:
337
+ params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
338
+
339
+ if train_text_encoder1:
340
+ training_models.append(text_encoder1)
341
+ params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
342
+ if train_text_encoder2:
343
+ training_models.append(text_encoder2)
344
+ params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
345
+
346
+ # calculate number of trainable parameters
347
+ n_params = 0
348
+ for group in params_to_optimize:
349
+ for p in group["params"]:
350
+ n_params += p.numel()
351
+
352
+ accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
353
+ accelerator.print(f"number of models: {len(training_models)}")
354
+ accelerator.print(f"number of trainable parameters: {n_params}")
355
+
356
+ # 学習に必要なクラスを準備する
357
+ accelerator.print("prepare optimizer, data loader etc.")
358
+
359
+ if args.fused_optimizer_groups:
360
+ # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
361
+ # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
362
+ # This balances memory usage and management complexity.
363
+
364
+ # calculate total number of parameters
365
+ n_total_params = sum(len(params["params"]) for params in params_to_optimize)
366
+ params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
367
+
368
+ # split params into groups, keeping the learning rate the same for all params in a group
369
+ # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
370
+ grouped_params = []
371
+ param_group = []
372
+ param_group_lr = -1
373
+ for group in params_to_optimize:
374
+ lr = group["lr"]
375
+ for p in group["params"]:
376
+ # if the learning rate is different for different params, start a new group
377
+ if lr != param_group_lr:
378
+ if param_group:
379
+ grouped_params.append({"params": param_group, "lr": param_group_lr})
380
+ param_group = []
381
+ param_group_lr = lr
382
+
383
+ param_group.append(p)
384
+
385
+ # if the group has enough parameters, start a new group
386
+ if len(param_group) == params_per_group:
387
+ grouped_params.append({"params": param_group, "lr": param_group_lr})
388
+ param_group = []
389
+ param_group_lr = -1
390
+
391
+ if param_group:
392
+ grouped_params.append({"params": param_group, "lr": param_group_lr})
393
+
394
+ # prepare optimizers for each group
395
+ optimizers = []
396
+ for group in grouped_params:
397
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
398
+ optimizers.append(optimizer)
399
+ optimizer = optimizers[0] # avoid error in the following code
400
+
401
+ logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
402
+
403
+ else:
404
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
405
+
406
+ # dataloaderを準備する
407
+ # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
408
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
409
+ train_dataloader = torch.utils.data.DataLoader(
410
+ train_dataset_group,
411
+ batch_size=1,
412
+ shuffle=True,
413
+ collate_fn=collator,
414
+ num_workers=n_workers,
415
+ persistent_workers=args.persistent_data_loader_workers,
416
+ )
417
+
418
+ # 学習ステップ数を計算する
419
+ if args.max_train_epochs is not None:
420
+ args.max_train_steps = args.max_train_epochs * math.ceil(
421
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
422
+ )
423
+ accelerator.print(
424
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
425
+ )
426
+
427
+ # データセット側にも学習ステップを送信
428
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
429
+
430
+ # lr schedulerを用意する
431
+ if args.fused_optimizer_groups:
432
+ # prepare lr schedulers for each optimizer
433
+ lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
434
+ lr_scheduler = lr_schedulers[0] # avoid error in the following code
435
+ else:
436
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
437
+
438
+ # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
439
+ if args.full_fp16:
440
+ assert (
441
+ args.mixed_precision == "fp16"
442
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
443
+ accelerator.print("enable full fp16 training.")
444
+ unet.to(weight_dtype)
445
+ text_encoder1.to(weight_dtype)
446
+ text_encoder2.to(weight_dtype)
447
+ elif args.full_bf16:
448
+ assert (
449
+ args.mixed_precision == "bf16"
450
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
451
+ accelerator.print("enable full bf16 training.")
452
+ unet.to(weight_dtype)
453
+ text_encoder1.to(weight_dtype)
454
+ text_encoder2.to(weight_dtype)
455
+
456
+ # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
457
+ if train_text_encoder1:
458
+ text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
459
+ text_encoder1.text_model.final_layer_norm.requires_grad_(False)
460
+
461
+ if args.deepspeed:
462
+ ds_model = deepspeed_utils.prepare_deepspeed_model(
463
+ args,
464
+ unet=unet if train_unet else None,
465
+ text_encoder1=text_encoder1 if train_text_encoder1 else None,
466
+ text_encoder2=text_encoder2 if train_text_encoder2 else None,
467
+ )
468
+ # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
469
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
470
+ ds_model, optimizer, train_dataloader, lr_scheduler
471
+ )
472
+ training_models = [ds_model]
473
+
474
+ else:
475
+ # acceleratorがなんかよろしくやってくれるらしい
476
+ if train_unet:
477
+ unet = accelerator.prepare(unet)
478
+ if train_text_encoder1:
479
+ text_encoder1 = accelerator.prepare(text_encoder1)
480
+ if train_text_encoder2:
481
+ text_encoder2 = accelerator.prepare(text_encoder2)
482
+ optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
483
+
484
+ # TextEncoderの出力をキャッシュするときにはCPUへ移動する
485
+ if args.cache_text_encoder_outputs:
486
+ # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
487
+ text_encoder1.to("cpu", dtype=torch.float32)
488
+ text_encoder2.to("cpu", dtype=torch.float32)
489
+ clean_memory_on_device(accelerator.device)
490
+ else:
491
+ # make sure Text Encoders are on GPU
492
+ text_encoder1.to(accelerator.device)
493
+ text_encoder2.to(accelerator.device)
494
+
495
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
496
+ if args.full_fp16:
497
+ # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
498
+ # -> But we think it's ok to patch accelerator even if deepspeed is enabled.
499
+ train_util.patch_accelerator_for_fp16_training(accelerator)
500
+
501
+ # resumeする
502
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
503
+
504
+ if args.fused_backward_pass:
505
+ # use fused optimizer for backward pass: other optimizers will be supported in the future
506
+ import library.adafactor_fused
507
+
508
+ library.adafactor_fused.patch_adafactor_fused(optimizer)
509
+ for param_group in optimizer.param_groups:
510
+ for parameter in param_group["params"]:
511
+ if parameter.requires_grad:
512
+
513
+ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
514
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
515
+ accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
516
+ optimizer.step_param(tensor, param_group)
517
+ tensor.grad = None
518
+
519
+ parameter.register_post_accumulate_grad_hook(__grad_hook)
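+ # fused backward pass: register_post_accumulate_grad_hook fires right after a parameter's gradient
+ # has been accumulated, so the patched Adafactor steps that single parameter and frees its grad
+ # immediately instead of keeping every gradient resident until a global optimizer.step()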
520
+
521
+ elif args.fused_optimizer_groups:
522
+ # prepare for additional optimizers and lr schedulers
523
+ for i in range(1, len(optimizers)):
524
+ optimizers[i] = accelerator.prepare(optimizers[i])
525
+ lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
526
+
527
+ # counters are used to determine when to step the optimizer
528
+ global optimizer_hooked_count
529
+ global num_parameters_per_group
530
+ global parameter_optimizer_map
531
+
532
+ optimizer_hooked_count = {}
533
+ num_parameters_per_group = [0] * len(optimizers)
534
+ parameter_optimizer_map = {}
535
+
536
+ for opt_idx, optimizer in enumerate(optimizers):
537
+ for param_group in optimizer.param_groups:
538
+ for parameter in param_group["params"]:
539
+ if parameter.requires_grad:
540
+
541
+ def optimizer_hook(parameter: torch.Tensor):
542
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
543
+ accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
544
+
545
+ i = parameter_optimizer_map[parameter]
546
+ optimizer_hooked_count[i] += 1
547
+ if optimizer_hooked_count[i] == num_parameters_per_group[i]:
548
+ optimizers[i].step()
549
+ optimizers[i].zero_grad(set_to_none=True)
550
+
551
+ parameter.register_post_accumulate_grad_hook(optimizer_hook)
552
+ parameter_optimizer_map[parameter] = opt_idx
553
+ num_parameters_per_group[opt_idx] += 1
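+ # fused optimizer groups: each parameter is mapped to its optimizer; the hook above counts gradients
+ # as they arrive and steps/zeroes a group's optimizer as soon as all of that group's parameters have
+ # their grads, so the optimizers run during the backward pass instead of after it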
554
+
555
+ # calculate the number of epochs
556
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
557
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
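+ # one "update step" corresponds to gradient_accumulation_steps dataloader batches, so the epoch count is derived from max_train_steps accordingly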
558
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
559
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
560
+
561
+ # start training
562
+ # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
563
+ accelerator.print("running training / 学習開始")
564
+ accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
565
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
566
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
567
+ accelerator.print(
568
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
569
+ )
570
+ # accelerator.print(
571
+ # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
572
+ # )
573
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
574
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
575
+
576
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
577
+ global_step = 0
578
+
579
+ noise_scheduler = DDPMScheduler(
580
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
581
+ )
582
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
583
+ if args.zero_terminal_snr:
584
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
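+ # (zero terminal SNR rescales the betas so the last timestep is pure noise, i.e. its SNR is exactly zero)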
585
+
586
+ if accelerator.is_main_process:
587
+ init_kwargs = {}
588
+ if args.wandb_run_name:
589
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
590
+ if args.log_tracker_config is not None:
591
+ init_kwargs = toml.load(args.log_tracker_config)
592
+ accelerator.init_trackers(
593
+ "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
594
+ config=train_util.get_sanitized_config_or_none(args),
595
+ init_kwargs=init_kwargs,
596
+ )
597
+
598
+ # For --sample_at_first
599
+ sdxl_train_util.sample_images(
600
+ accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
601
+ )
602
+
603
+ loss_recorder = train_util.LossRecorder()
604
+ for epoch in range(num_train_epochs):
605
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
606
+ current_epoch.value = epoch + 1
607
+
608
+ for m in training_models:
609
+ m.train()
610
+
611
+ for step, batch in enumerate(train_dataloader):
612
+ current_step.value = global_step
613
+
614
+ if args.fused_optimizer_groups:
615
+ optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
616
+
617
+ with accelerator.accumulate(*training_models):
618
+ if "latents" in batch and batch["latents"] is not None:
619
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
620
+ else:
621
+ with torch.no_grad():
622
+ # convert to latents
623
+ latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
624
+
625
+ # if NaNs are found, show a warning and replace them with zeros
626
+ if torch.any(torch.isnan(latents)):
627
+ accelerator.print("NaN found in latents, replacing with zeros")
628
+ latents = torch.nan_to_num(latents, 0, out=latents)
629
+ latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
630
+
631
+ if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
632
+ input_ids1 = batch["input_ids"]
633
+ input_ids2 = batch["input_ids2"]
634
+ with torch.set_grad_enabled(args.train_text_encoder):
635
+ # Get the text embedding for conditioning
636
+ # TODO support weighted captions
637
+ # if args.weighted_captions:
638
+ # encoder_hidden_states = get_weighted_text_embeddings(
639
+ # tokenizer,
640
+ # text_encoder,
641
+ # batch["captions"],
642
+ # accelerator.device,
643
+ # args.max_token_length // 75 if args.max_token_length else 1,
644
+ # clip_skip=args.clip_skip,
645
+ # )
646
+ # else:
647
+ input_ids1 = input_ids1.to(accelerator.device)
648
+ input_ids2 = input_ids2.to(accelerator.device)
649
+ # unwrap_model is fine for models not wrapped by accelerator
650
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
651
+ args.max_token_length,
652
+ input_ids1,
653
+ input_ids2,
654
+ tokenizer1,
655
+ tokenizer2,
656
+ text_encoder1,
657
+ text_encoder2,
658
+ None if not args.full_fp16 else weight_dtype,
659
+ accelerator=accelerator,
660
+ )
661
+ else:
662
+ encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
663
+ encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
664
+ pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
665
+
666
+ # # verify that the text encoder outputs are correct
667
+ # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
668
+ # args.max_token_length,
669
+ # batch["input_ids"].to(text_encoder1.device),
670
+ # batch["input_ids2"].to(text_encoder1.device),
671
+ # tokenizer1,
672
+ # tokenizer2,
673
+ # text_encoder1,
674
+ # text_encoder2,
675
+ # None if not args.full_fp16 else weight_dtype,
676
+ # )
677
+ # b_size = encoder_hidden_states1.shape[0]
678
+ # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
679
+ # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
680
+ # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
681
+ # logger.info("text encoder outputs verified")
682
+
683
+ # get size embeddings
684
+ orig_size = batch["original_sizes_hw"]
685
+ crop_size = batch["crop_top_lefts"]
686
+ target_size = batch["target_sizes_hw"]
687
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
688
+
689
+ # concat embeddings
690
+ vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
691
+ text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
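+ # SDXL conditioning: the U-Net "context" is the token-wise hidden states of both text encoders
+ # concatenated along the feature dim, and the extra vector conditioning is the pooled output of
+ # text encoder 2 concatenated with the original/crop/target size embeddings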
692
+
693
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
694
+ # with noise offset and/or multires noise if specified
695
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
696
+ args, noise_scheduler, latents
697
+ )
698
+
699
+ noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
700
+
701
+ # Predict the noise residual
702
+ with accelerator.autocast():
703
+ noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
704
+
705
+ if args.v_parameterization:
706
+ # v-parameterization training
707
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
708
+ else:
709
+ target = noise
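+ # epsilon prediction targets the added noise; v-parameterization targets the scheduler's velocity (roughly alpha_t * noise - sigma_t * latents)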
710
+
711
+ if (
712
+ args.min_snr_gamma
713
+ or args.scale_v_pred_loss_like_noise_pred
714
+ or args.v_pred_like_loss
715
+ or args.debiased_estimation_loss
716
+ or args.masked_loss
717
+ ):
718
+ # do not mean over batch dimension for snr weight or scale v-pred loss
719
+ loss = train_util.conditional_loss(
720
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
721
+ )
722
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
723
+ loss = apply_masked_loss(loss, batch)
724
+ loss = loss.mean([1, 2, 3])
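+ # reduce over C/H/W only, keeping one loss value per sample so the timestep-dependent weights below can be applied per sample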
725
+
726
+ if args.min_snr_gamma:
727
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
728
+ if args.scale_v_pred_loss_like_noise_pred:
729
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
730
+ if args.v_pred_like_loss:
731
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
732
+ if args.debiased_estimation_loss:
733
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
734
+
735
+ loss = loss.mean() # mean over batch dimension
736
+ else:
737
+ loss = train_util.conditional_loss(
738
+ noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
739
+ )
740
+
741
+ accelerator.backward(loss)
742
+
743
+ if not (args.fused_backward_pass or args.fused_optimizer_groups):
744
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
745
+ params_to_clip = []
746
+ for m in training_models:
747
+ params_to_clip.extend(m.parameters())
748
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
749
+
750
+ optimizer.step()
751
+ lr_scheduler.step()
752
+ optimizer.zero_grad(set_to_none=True)
753
+ else:
754
+ # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
755
+ lr_scheduler.step()
756
+ if args.fused_optimizer_groups:
757
+ for i in range(1, len(optimizers)):
758
+ lr_schedulers[i].step()
759
+
760
+ # Checks if the accelerator has performed an optimization step behind the scenes
761
+ if accelerator.sync_gradients:
762
+ progress_bar.update(1)
763
+ global_step += 1
764
+
765
+ sdxl_train_util.sample_images(
766
+ accelerator,
767
+ args,
768
+ None,
769
+ global_step,
770
+ accelerator.device,
771
+ vae,
772
+ [tokenizer1, tokenizer2],
773
+ [text_encoder1, text_encoder2],
774
+ unet,
775
+ )
776
+
777
+ # save the model every specified number of steps
778
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
779
+ accelerator.wait_for_everyone()
780
+ if accelerator.is_main_process:
781
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
782
+ sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
783
+ args,
784
+ False,
785
+ accelerator,
786
+ src_path,
787
+ save_stable_diffusion_format,
788
+ use_safetensors,
789
+ save_dtype,
790
+ epoch,
791
+ num_train_epochs,
792
+ global_step,
793
+ accelerator.unwrap_model(text_encoder1),
794
+ accelerator.unwrap_model(text_encoder2),
795
+ accelerator.unwrap_model(unet),
796
+ vae,
797
+ logit_scale,
798
+ ckpt_info,
799
+ )
800
+
801
+ current_loss = loss.detach().item() # this is already a mean, so batch size should not matter
802
+ if args.logging_dir is not None:
803
+ logs = {"loss": current_loss}
804
+ if block_lrs is None:
805
+ train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
806
+ else:
807
+ append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
808
+
809
+ accelerator.log(logs, step=global_step)
810
+
811
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
812
+ avr_loss: float = loss_recorder.moving_average
813
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
814
+ progress_bar.set_postfix(**logs)
815
+
816
+ if global_step >= args.max_train_steps:
817
+ break
818
+
819
+ if args.logging_dir is not None:
820
+ logs = {"loss/epoch": loss_recorder.moving_average}
821
+ accelerator.log(logs, step=epoch + 1)
822
+
823
+ accelerator.wait_for_everyone()
824
+
825
+ if args.save_every_n_epochs is not None:
826
+ if accelerator.is_main_process:
827
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
828
+ sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
829
+ args,
830
+ True,
831
+ accelerator,
832
+ src_path,
833
+ save_stable_diffusion_format,
834
+ use_safetensors,
835
+ save_dtype,
836
+ epoch,
837
+ num_train_epochs,
838
+ global_step,
839
+ accelerator.unwrap_model(text_encoder1),
840
+ accelerator.unwrap_model(text_encoder2),
841
+ accelerator.unwrap_model(unet),
842
+ vae,
843
+ logit_scale,
844
+ ckpt_info,
845
+ )
846
+
847
+ sdxl_train_util.sample_images(
848
+ accelerator,
849
+ args,
850
+ epoch + 1,
851
+ global_step,
852
+ accelerator.device,
853
+ vae,
854
+ [tokenizer1, tokenizer2],
855
+ [text_encoder1, text_encoder2],
856
+ unet,
857
+ )
858
+
859
+ is_main_process = accelerator.is_main_process
860
+ # if is_main_process:
861
+ unet = accelerator.unwrap_model(unet)
862
+ text_encoder1 = accelerator.unwrap_model(text_encoder1)
863
+ text_encoder2 = accelerator.unwrap_model(text_encoder2)
864
+
865
+ accelerator.end_training()
866
+
867
+ if args.save_state or args.save_state_on_train_end:
868
+ train_util.save_state_on_train_end(args, accelerator)
869
+
870
+ del accelerator # free this since memory will be needed afterwards
871
+
872
+ if is_main_process:
873
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
874
+ sdxl_train_util.save_sd_model_on_train_end(
875
+ args,
876
+ src_path,
877
+ save_stable_diffusion_format,
878
+ use_safetensors,
879
+ save_dtype,
880
+ epoch,
881
+ global_step,
882
+ text_encoder1,
883
+ text_encoder2,
884
+ unet,
885
+ vae,
886
+ logit_scale,
887
+ ckpt_info,
888
+ )
889
+ logger.info("model saved.")
890
+
891
+
892
+ def setup_parser() -> argparse.ArgumentParser:
893
+ parser = argparse.ArgumentParser()
894
+
895
+ add_logging_arguments(parser)
896
+ train_util.add_sd_models_arguments(parser)
897
+ train_util.add_dataset_arguments(parser, True, True, True)
898
+ train_util.add_training_arguments(parser, False)
899
+ train_util.add_masked_loss_arguments(parser)
900
+ deepspeed_utils.add_deepspeed_arguments(parser)
901
+ train_util.add_sd_saving_arguments(parser)
902
+ train_util.add_optimizer_arguments(parser)
903
+ config_util.add_config_arguments(parser)
904
+ custom_train_functions.add_custom_train_arguments(parser)
905
+ sdxl_train_util.add_sdxl_training_arguments(parser)
906
+
907
+ parser.add_argument(
908
+ "--learning_rate_te1",
909
+ type=float,
910
+ default=None,
911
+ help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
912
+ )
913
+ parser.add_argument(
914
+ "--learning_rate_te2",
915
+ type=float,
916
+ default=None,
917
+ help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
918
+ )
919
+
920
+ parser.add_argument(
921
+ "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
922
+ )
923
+ parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
924
+ parser.add_argument(
925
+ "--no_half_vae",
926
+ action="store_true",
927
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
928
+ )
929
+ parser.add_argument(
930
+ "--block_lr",
931
+ type=str,
932
+ default=None,
933
+ help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
934
+ + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
935
+ )
936
+ parser.add_argument(
937
+ "--fused_optimizer_groups",
938
+ type=int,
939
+ default=None,
940
+ help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
941
+ )
942
+ return parser
943
+
944
+
945
+ if __name__ == "__main__":
946
+ parser = setup_parser()
947
+
948
+ args = parser.parse_args()
949
+ train_util.verify_command_line_training_args(args)
950
+ args = train_util.read_config_from_file(args, parser)
951
+
952
+ train(args)
sdxl_train_control_net_lllite.py ADDED
@@ -0,0 +1,626 @@
1
+ # cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
2
+ # training code for ControlNet-LLLite with passing cond_image to U-Net's forward
3
+
4
+ import argparse
5
+ import json
6
+ import math
7
+ import os
8
+ import random
9
+ import time
10
+ from multiprocessing import Value
11
+ from types import SimpleNamespace
12
+ import toml
13
+
14
+ from tqdm import tqdm
15
+
16
+ import torch
17
+ from library.device_utils import init_ipex, clean_memory_on_device
18
+
19
+ init_ipex()
20
+
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ from accelerate.utils import set_seed
23
+ import accelerate
24
+ from diffusers import DDPMScheduler, ControlNetModel
25
+ from safetensors.torch import load_file
26
+ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
27
+
28
+ import library.model_util as model_util
29
+ import library.train_util as train_util
30
+ import library.config_util as config_util
31
+ from library.config_util import (
32
+ ConfigSanitizer,
33
+ BlueprintGenerator,
34
+ )
35
+ import library.huggingface_util as huggingface_util
36
+ import library.custom_train_functions as custom_train_functions
37
+ from library.custom_train_functions import (
38
+ add_v_prediction_like_loss,
39
+ apply_snr_weight,
40
+ prepare_scheduler_for_custom_training,
41
+ pyramid_noise_like,
42
+ apply_noise_offset,
43
+ scale_v_prediction_loss_like_noise_prediction,
44
+ apply_debiased_estimation,
45
+ )
46
+ import networks.control_net_lllite_for_train as control_net_lllite_for_train
47
+ from library.utils import setup_logging, add_logging_arguments
48
+
49
+ setup_logging()
50
+ import logging
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+
55
+ # TODO 他のスクリプトと共通化する
56
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
57
+ logs = {
58
+ "loss/current": current_loss,
59
+ "loss/average": avr_loss,
60
+ "lr": lr_scheduler.get_last_lr()[0],
61
+ }
62
+
63
+ if args.optimizer_type.lower().startswith("DAdapt".lower()):
64
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
65
+
66
+ return logs
67
+
68
+
69
+ def train(args):
70
+ train_util.verify_training_args(args)
71
+ train_util.prepare_dataset_args(args, True)
72
+ sdxl_train_util.verify_sdxl_training_args(args)
73
+ setup_logging(args, reset=True)
74
+
75
+ cache_latents = args.cache_latents
76
+ use_user_config = args.dataset_config is not None
77
+
78
+ if args.seed is None:
79
+ args.seed = random.randint(0, 2**32)
80
+ set_seed(args.seed)
81
+
82
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
83
+
84
+ # データセットを準備する
85
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
86
+ if use_user_config:
87
+ logger.info(f"Load dataset config from {args.dataset_config}")
88
+ user_config = config_util.load_user_config(args.dataset_config)
89
+ ignored = ["train_data_dir", "conditioning_data_dir"]
90
+ if any(getattr(args, attr) is not None for attr in ignored):
91
+ logger.warning(
92
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
93
+ ", ".join(ignored)
94
+ )
95
+ )
96
+ else:
97
+ user_config = {
98
+ "datasets": [
99
+ {
100
+ "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
101
+ args.train_data_dir,
102
+ args.conditioning_data_dir,
103
+ args.caption_extension,
104
+ )
105
+ }
106
+ ]
107
+ }
108
+
109
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
110
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
111
+
112
+ current_epoch = Value("i", 0)
113
+ current_step = Value("i", 0)
114
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
115
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
116
+
117
+ train_dataset_group.verify_bucket_reso_steps(32)
118
+
119
+ if args.debug_dataset:
120
+ train_util.debug_dataset(train_dataset_group)
121
+ return
122
+ if len(train_dataset_group) == 0:
123
+ logger.error(
124
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
125
+ )
126
+ return
127
+
128
+ if cache_latents:
129
+ assert (
130
+ train_dataset_group.is_latent_cacheable()
131
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
132
+ else:
133
+ logger.warning(
134
+ "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
135
+ )
136
+
137
+ if args.cache_text_encoder_outputs:
138
+ assert (
139
+ train_dataset_group.is_text_encoder_output_cacheable()
140
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
141
+
142
+ # acceleratorを準備する
143
+ logger.info("prepare accelerator")
144
+ accelerator = train_util.prepare_accelerator(args)
145
+ is_main_process = accelerator.is_main_process
146
+
147
+ # mixed precisionに対応した型を用意しておき適宜castする
148
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
149
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
150
+
151
+ # モデルを読み込む
152
+ (
153
+ load_stable_diffusion_format,
154
+ text_encoder1,
155
+ text_encoder2,
156
+ vae,
157
+ unet,
158
+ logit_scale,
159
+ ckpt_info,
160
+ ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
161
+
162
+ # 学習を準備する
163
+ if cache_latents:
164
+ vae.to(accelerator.device, dtype=vae_dtype)
165
+ vae.requires_grad_(False)
166
+ vae.eval()
167
+ with torch.no_grad():
168
+ train_dataset_group.cache_latents(
169
+ vae,
170
+ args.vae_batch_size,
171
+ args.cache_latents_to_disk,
172
+ accelerator.is_main_process,
173
+ )
174
+ vae.to("cpu")
175
+ clean_memory_on_device(accelerator.device)
176
+
177
+ accelerator.wait_for_everyone()
178
+
179
+ # TextEncoderの出力をキャッシュする
180
+ if args.cache_text_encoder_outputs:
181
+ # Text Encoders are kept in eval mode with gradients disabled
182
+ with torch.no_grad():
183
+ train_dataset_group.cache_text_encoder_outputs(
184
+ (tokenizer1, tokenizer2),
185
+ (text_encoder1, text_encoder2),
186
+ accelerator.device,
187
+ None,
188
+ args.cache_text_encoder_outputs_to_disk,
189
+ accelerator.is_main_process,
190
+ )
191
+ accelerator.wait_for_everyone()
192
+
193
+ # prepare ControlNet-LLLite
194
+ control_net_lllite_for_train.replace_unet_linear_and_conv2d()
195
+
196
+ if args.network_weights is not None:
197
+ accelerator.print(f"initialize U-Net with ControlNet-LLLite")
198
+ with accelerate.init_empty_weights():
199
+ unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
200
+ unet_lllite.to(accelerator.device, dtype=weight_dtype)
201
+
202
+ unet_sd = unet.state_dict()
203
+ info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
204
+ accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
205
+ else:
206
+ # consumes large memory, so send to GPU before creating the LLLite model
207
+ accelerator.print("sending U-Net to GPU")
208
+ unet.to(accelerator.device, dtype=weight_dtype)
209
+ unet_sd = unet.state_dict()
210
+
211
+ # init LLLite weights
212
+ accelerator.print(f"initialize U-Net with ControlNet-LLLite")
213
+
214
+ if args.lowram:
215
+ with accelerate.init_on_device(accelerator.device):
216
+ unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
217
+ else:
218
+ unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
219
+ unet_lllite.to(weight_dtype)
220
+
221
+ info = unet_lllite.load_lllite_weights(None, unet_sd)
222
+ accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
223
+ del unet_sd, unet
224
+
225
+ unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
226
+ del unet_lllite
227
+
228
+ unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
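+ # in this script the LLLite modules live inside the U-Net itself: the stock Linear/Conv2d classes were
+ # swapped out by replace_unet_linear_and_conv2d above, and apply_lllite attaches the conditioning
+ # modules, so the conditioning image can simply be passed as an extra argument to unet(...)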
229
+
230
+ # モデルに xformers とか memory efficient attention を組み込む
231
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
232
+
233
+ if args.gradient_checkpointing:
234
+ unet.enable_gradient_checkpointing()
235
+
236
+ # 学習に必要なクラスを準備する
237
+ accelerator.print("prepare optimizer, data loader etc.")
238
+
239
+ trainable_params = list(unet.prepare_params())
240
+ logger.info(f"trainable params count: {len(trainable_params)}")
241
+ logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
242
+
243
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
244
+
245
+ # dataloaderを準備する
246
+ # note that persistent_workers cannot be used when the number of DataLoader workers is 0
247
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
248
+
249
+ train_dataloader = torch.utils.data.DataLoader(
250
+ train_dataset_group,
251
+ batch_size=1,
252
+ shuffle=True,
253
+ collate_fn=collator,
254
+ num_workers=n_workers,
255
+ persistent_workers=args.persistent_data_loader_workers,
256
+ )
257
+
258
+ # 学習ステップ数を計算する
259
+ if args.max_train_epochs is not None:
260
+ args.max_train_steps = args.max_train_epochs * math.ceil(
261
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
262
+ )
263
+ accelerator.print(
264
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
265
+ )
266
+
267
+ # データセット側にも学習ステップを送信
268
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
269
+
270
+ # lr schedulerを用意する
271
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
272
+
273
+ # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
274
+ # if args.full_fp16:
275
+ # assert (
276
+ # args.mixed_precision == "fp16"
277
+ # ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
278
+ # accelerator.print("enable full fp16 training.")
279
+ # unet.to(weight_dtype)
280
+ # elif args.full_bf16:
281
+ # assert (
282
+ # args.mixed_precision == "bf16"
283
+ # ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
284
+ # accelerator.print("enable full bf16 training.")
285
+ # unet.to(weight_dtype)
286
+
287
+ unet.to(weight_dtype)
288
+
289
+ # accelerator apparently takes care of the rest for us here
290
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
291
+
292
+ if isinstance(unet, DDP):
293
+ unet._set_static_graph() # avoid error for multiple use of the parameter
294
+
295
+ if args.gradient_checkpointing:
296
+ unet.train() # according to the TI example in Diffusers, train() is required -> since this was changed to the original U-Net, it could probably be removed now
297
+ else:
298
+ unet.eval()
299
+
300
+ # TextEncoderの出力をキャッシュするときにはCPUへ移動する
301
+ if args.cache_text_encoder_outputs:
302
+ # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
303
+ text_encoder1.to("cpu", dtype=torch.float32)
304
+ text_encoder2.to("cpu", dtype=torch.float32)
305
+ clean_memory_on_device(accelerator.device)
306
+ else:
307
+ # make sure Text Encoders are on GPU
308
+ text_encoder1.to(accelerator.device)
309
+ text_encoder2.to(accelerator.device)
310
+
311
+ if not cache_latents:
312
+ vae.requires_grad_(False)
313
+ vae.eval()
314
+ vae.to(accelerator.device, dtype=vae_dtype)
315
+
316
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
317
+ if args.full_fp16:
318
+ train_util.patch_accelerator_for_fp16_training(accelerator)
319
+
320
+ # resumeする
321
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
322
+
323
+ # epoch数を計算する
324
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
325
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
326
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
327
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
328
+
329
+ # 学習する
330
+ # TODO: find a way to handle total batch size when there are multiple datasets
331
+ accelerator.print("running training / 学習開始")
332
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
333
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
334
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
335
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
336
+ accelerator.print(
337
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
338
+ )
339
+ # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
340
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
341
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
342
+
343
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
344
+ global_step = 0
345
+
346
+ noise_scheduler = DDPMScheduler(
347
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
348
+ )
349
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
350
+ if args.zero_terminal_snr:
351
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
352
+
353
+ if accelerator.is_main_process:
354
+ init_kwargs = {}
355
+ if args.wandb_run_name:
356
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
357
+ if args.log_tracker_config is not None:
358
+ init_kwargs = toml.load(args.log_tracker_config)
359
+ accelerator.init_trackers(
360
+ "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
361
+ )
362
+
363
+ loss_recorder = train_util.LossRecorder()
364
+ del train_dataset_group
365
+
366
+ # function for saving/removing
367
+ def save_model(
368
+ ckpt_name,
369
+ unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
370
+ steps,
371
+ epoch_no,
372
+ force_sync_upload=False,
373
+ ):
374
+ os.makedirs(args.output_dir, exist_ok=True)
375
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
376
+
377
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
378
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
379
+ sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
380
+
381
+ unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
382
+ if args.huggingface_repo_id is not None:
383
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
384
+
385
+ def remove_model(old_ckpt_name):
386
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
387
+ if os.path.exists(old_ckpt_file):
388
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
389
+ os.remove(old_ckpt_file)
390
+
391
+ # training loop
392
+ for epoch in range(num_train_epochs):
393
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
394
+ current_epoch.value = epoch + 1
395
+
396
+ for step, batch in enumerate(train_dataloader):
397
+ current_step.value = global_step
398
+ with accelerator.accumulate(unet):
399
+ with torch.no_grad():
400
+ if "latents" in batch and batch["latents"] is not None:
401
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
402
+ else:
403
+ # latentに変換
404
+ latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
405
+
406
+ # if NaNs are found, show a warning and replace them with zeros
407
+ if torch.any(torch.isnan(latents)):
408
+ accelerator.print("NaN found in latents, replacing with zeros")
409
+ latents = torch.nan_to_num(latents, 0, out=latents)
410
+ latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
411
+
412
+ if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
413
+ input_ids1 = batch["input_ids"]
414
+ input_ids2 = batch["input_ids2"]
415
+ with torch.no_grad():
416
+ # Get the text embedding for conditioning
417
+ input_ids1 = input_ids1.to(accelerator.device)
418
+ input_ids2 = input_ids2.to(accelerator.device)
419
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
420
+ args.max_token_length,
421
+ input_ids1,
422
+ input_ids2,
423
+ tokenizer1,
424
+ tokenizer2,
425
+ text_encoder1,
426
+ text_encoder2,
427
+ None if not args.full_fp16 else weight_dtype,
428
+ )
429
+ else:
430
+ encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
431
+ encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
432
+ pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
433
+
434
+ # get size embeddings
435
+ orig_size = batch["original_sizes_hw"]
436
+ crop_size = batch["crop_top_lefts"]
437
+ target_size = batch["target_sizes_hw"]
438
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
439
+
440
+ # concat embeddings
441
+ vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
442
+ text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
443
+
444
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
445
+ # with noise offset and/or multires noise if specified
446
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
447
+ args, noise_scheduler, latents
448
+ )
449
+
450
+ noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
451
+
452
+ controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
453
+
454
+ with accelerator.autocast():
455
+ # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
456
+ # 内部でcond_embに変換される / it will be converted to cond_emb inside
457
+
458
+ # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
459
+ noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
460
+
461
+ if args.v_parameterization:
462
+ # v-parameterization training
463
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
464
+ else:
465
+ target = noise
466
+
467
+ loss = train_util.conditional_loss(
468
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
469
+ )
470
+ loss = loss.mean([1, 2, 3])
471
+
472
+ loss_weights = batch["loss_weights"] # per-sample weights
473
+ loss = loss * loss_weights
474
+
475
+ if args.min_snr_gamma:
476
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
477
+ if args.scale_v_pred_loss_like_noise_pred:
478
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
479
+ if args.v_pred_like_loss:
480
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
481
+ if args.debiased_estimation_loss:
482
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
483
+
484
+ loss = loss.mean() # already a mean, so no need to divide by batch_size
485
+
486
+ accelerator.backward(loss)
487
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
488
+ params_to_clip = accelerator.unwrap_model(unet).get_trainable_params()
489
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
490
+
491
+ optimizer.step()
492
+ lr_scheduler.step()
493
+ optimizer.zero_grad(set_to_none=True)
494
+
495
+ # Checks if the accelerator has performed an optimization step behind the scenes
496
+ if accelerator.sync_gradients:
497
+ progress_bar.update(1)
498
+ global_step += 1
499
+
500
+ # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
501
+
502
+ # 指定ステップごとにモデルを保存
503
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
504
+ accelerator.wait_for_everyone()
505
+ if accelerator.is_main_process:
506
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
507
+ save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
508
+
509
+ if args.save_state:
510
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
511
+
512
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
513
+ if remove_step_no is not None:
514
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
515
+ remove_model(remove_ckpt_name)
516
+
517
+ current_loss = loss.detach().item()
518
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
519
+ avr_loss: float = loss_recorder.moving_average
520
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
521
+ progress_bar.set_postfix(**logs)
522
+
523
+ if args.logging_dir is not None:
524
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
525
+ accelerator.log(logs, step=global_step)
526
+
527
+ if global_step >= args.max_train_steps:
528
+ break
529
+
530
+ if args.logging_dir is not None:
531
+ logs = {"loss/epoch": loss_recorder.moving_average}
532
+ accelerator.log(logs, step=epoch + 1)
533
+
534
+ accelerator.wait_for_everyone()
535
+
536
+ # 指定エポックごとにモデルを保存
537
+ if args.save_every_n_epochs is not None:
538
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
539
+ if is_main_process and saving:
540
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
541
+ save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
542
+
543
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
544
+ if remove_epoch_no is not None:
545
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
546
+ remove_model(remove_ckpt_name)
547
+
548
+ if args.save_state:
549
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
550
+
551
+ # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
552
+
553
+ # end of epoch
554
+
555
+ if is_main_process:
556
+ unet = accelerator.unwrap_model(unet)
557
+
558
+ accelerator.end_training()
559
+
560
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
561
+ train_util.save_state_on_train_end(args, accelerator)
562
+
563
+ if is_main_process:
564
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
565
+ save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
566
+
567
+ logger.info("model saved.")
568
+
569
+
570
+ def setup_parser() -> argparse.ArgumentParser:
571
+ parser = argparse.ArgumentParser()
572
+
573
+ add_logging_arguments(parser)
574
+ train_util.add_sd_models_arguments(parser)
575
+ train_util.add_dataset_arguments(parser, False, True, True)
576
+ train_util.add_training_arguments(parser, False)
577
+ deepspeed_utils.add_deepspeed_arguments(parser)
578
+ train_util.add_optimizer_arguments(parser)
579
+ config_util.add_config_arguments(parser)
580
+ custom_train_functions.add_custom_train_arguments(parser)
581
+ sdxl_train_util.add_sdxl_training_arguments(parser)
582
+
583
+ parser.add_argument(
584
+ "--save_model_as",
585
+ type=str,
586
+ default="safetensors",
587
+ choices=[None, "ckpt", "pt", "safetensors"],
588
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
589
+ )
590
+ parser.add_argument(
591
+ "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
592
+ )
593
+ parser.add_argument(
594
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
595
+ )
596
+ parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
597
+ parser.add_argument(
598
+ "--network_dropout",
599
+ type=float,
600
+ default=None,
601
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
602
+ )
603
+ parser.add_argument(
604
+ "--conditioning_data_dir",
605
+ type=str,
606
+ default=None,
607
+ help="conditioning data directory / 条件付けデータのディレクトリ",
608
+ )
609
+ parser.add_argument(
610
+ "--no_half_vae",
611
+ action="store_true",
612
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
613
+ )
614
+ return parser
615
+
616
+
617
+ if __name__ == "__main__":
618
+ # sdxl_original_unet.USE_REENTRANT = False
619
+
620
+ parser = setup_parser()
621
+
622
+ args = parser.parse_args()
623
+ train_util.verify_command_line_training_args(args)
624
+ args = train_util.read_config_from_file(args, parser)
625
+
626
+ train(args)
sdxl_train_control_net_lllite_old.py ADDED
@@ -0,0 +1,586 @@
1
+ import argparse
2
+ import json
3
+ import math
4
+ import os
5
+ import random
6
+ import time
7
+ from multiprocessing import Value
8
+ from types import SimpleNamespace
9
+ import toml
10
+
11
+ from tqdm import tqdm
12
+
13
+ import torch
14
+ from library.device_utils import init_ipex, clean_memory_on_device
15
+ init_ipex()
16
+
17
+ from torch.nn.parallel import DistributedDataParallel as DDP
18
+ from accelerate.utils import set_seed
19
+ from diffusers import DDPMScheduler, ControlNetModel
20
+ from safetensors.torch import load_file
21
+ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
22
+
23
+ import library.model_util as model_util
24
+ import library.train_util as train_util
25
+ import library.config_util as config_util
26
+ from library.config_util import (
27
+ ConfigSanitizer,
28
+ BlueprintGenerator,
29
+ )
30
+ import library.huggingface_util as huggingface_util
31
+ import library.custom_train_functions as custom_train_functions
32
+ from library.custom_train_functions import (
33
+ add_v_prediction_like_loss,
34
+ apply_snr_weight,
35
+ prepare_scheduler_for_custom_training,
36
+ pyramid_noise_like,
37
+ apply_noise_offset,
38
+ scale_v_prediction_loss_like_noise_prediction,
39
+ apply_debiased_estimation,
40
+ )
41
+ import networks.control_net_lllite as control_net_lllite
42
+ from library.utils import setup_logging, add_logging_arguments
43
+
44
+ setup_logging()
45
+ import logging
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
+ # TODO 他のスクリプトと共通化する
51
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
52
+ logs = {
53
+ "loss/current": current_loss,
54
+ "loss/average": avr_loss,
55
+ "lr": lr_scheduler.get_last_lr()[0],
56
+ }
57
+
58
+ if args.optimizer_type.lower().startswith("DAdapt".lower()):
59
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
60
+
61
+ return logs
62
+
63
+
64
+ def train(args):
65
+ train_util.verify_training_args(args)
66
+ train_util.prepare_dataset_args(args, True)
67
+ sdxl_train_util.verify_sdxl_training_args(args)
68
+ setup_logging(args, reset=True)
69
+
70
+ cache_latents = args.cache_latents
71
+ use_user_config = args.dataset_config is not None
72
+
73
+ if args.seed is None:
74
+ args.seed = random.randint(0, 2**32)
75
+ set_seed(args.seed)
76
+
77
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
78
+
79
+ # データセットを準備する
80
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
81
+ if use_user_config:
82
+ logger.info(f"Load dataset config from {args.dataset_config}")
83
+ user_config = config_util.load_user_config(args.dataset_config)
84
+ ignored = ["train_data_dir", "conditioning_data_dir"]
85
+ if any(getattr(args, attr) is not None for attr in ignored):
86
+ logger.warning(
87
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
88
+ ", ".join(ignored)
89
+ )
90
+ )
91
+ else:
92
+ user_config = {
93
+ "datasets": [
94
+ {
95
+ "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
96
+ args.train_data_dir,
97
+ args.conditioning_data_dir,
98
+ args.caption_extension,
99
+ )
100
+ }
101
+ ]
102
+ }
103
+
104
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
105
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
106
+
107
+ current_epoch = Value("i", 0)
108
+ current_step = Value("i", 0)
109
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
110
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
111
+
112
+ train_dataset_group.verify_bucket_reso_steps(32)
113
+
114
+ if args.debug_dataset:
115
+ train_util.debug_dataset(train_dataset_group)
116
+ return
117
+ if len(train_dataset_group) == 0:
118
+ logger.error(
119
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
120
+ )
121
+ return
122
+
123
+ if cache_latents:
124
+ assert (
125
+ train_dataset_group.is_latent_cacheable()
126
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
127
+ else:
128
+ logger.warning(
129
+ "WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
130
+ )
131
+
132
+ if args.cache_text_encoder_outputs:
133
+ assert (
134
+ train_dataset_group.is_text_encoder_output_cacheable()
135
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
136
+
137
+ # acceleratorを準備する
138
+ logger.info("prepare accelerator")
139
+ accelerator = train_util.prepare_accelerator(args)
140
+ is_main_process = accelerator.is_main_process
141
+
142
+ # mixed precisionに対応した型を用意しておき適宜castする
143
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
144
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
145
+
146
+ # モデルを読み込む
147
+ (
148
+ load_stable_diffusion_format,
149
+ text_encoder1,
150
+ text_encoder2,
151
+ vae,
152
+ unet,
153
+ logit_scale,
154
+ ckpt_info,
155
+ ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
156
+
157
+ # モデルに xformers とか memory efficient attention を組み込む
158
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
159
+
160
+ # 学習を準備する
161
+ if cache_latents:
162
+ vae.to(accelerator.device, dtype=vae_dtype)
163
+ vae.requires_grad_(False)
164
+ vae.eval()
165
+ with torch.no_grad():
166
+ train_dataset_group.cache_latents(
167
+ vae,
168
+ args.vae_batch_size,
169
+ args.cache_latents_to_disk,
170
+ accelerator.is_main_process,
171
+ )
172
+ vae.to("cpu")
173
+ clean_memory_on_device(accelerator.device)
174
+
175
+ accelerator.wait_for_everyone()
176
+
177
+ # TextEncoderの出力をキャッシュする
178
+ if args.cache_text_encoder_outputs:
179
+ # Text Encoders are kept in eval mode with gradients disabled
180
+ with torch.no_grad():
181
+ train_dataset_group.cache_text_encoder_outputs(
182
+ (tokenizer1, tokenizer2),
183
+ (text_encoder1, text_encoder2),
184
+ accelerator.device,
185
+ None,
186
+ args.cache_text_encoder_outputs_to_disk,
187
+ accelerator.is_main_process,
188
+ )
189
+ accelerator.wait_for_everyone()
190
+
191
+ # prepare ControlNet
192
+ network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
193
+ network.apply_to()
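+ # unlike the newer script, this older variant keeps the U-Net itself unchanged and hooks a separate
+ # ControlNetLLLite network onto it; only the network's parameters are trained and saved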
194
+
195
+ if args.network_weights is not None:
196
+ info = network.load_weights(args.network_weights)
197
+ accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
198
+
199
+ if args.gradient_checkpointing:
200
+ unet.enable_gradient_checkpointing()
201
+ network.enable_gradient_checkpointing() # may have no effect
202
+
203
+ # 学習に必要なクラスを準備する
204
+ accelerator.print("prepare optimizer, data loader etc.")
205
+
206
+ trainable_params = list(network.prepare_optimizer_params())
207
+ logger.info(f"trainable params count: {len(trainable_params)}")
208
+ logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
209
+
210
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
211
+
212
+ # dataloaderを準備する
213
+ # note that persistent_workers cannot be used when the number of DataLoader workers is 0
214
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
215
+
216
+ train_dataloader = torch.utils.data.DataLoader(
217
+ train_dataset_group,
218
+ batch_size=1,
219
+ shuffle=True,
220
+ collate_fn=collator,
221
+ num_workers=n_workers,
222
+ persistent_workers=args.persistent_data_loader_workers,
223
+ )
224
+
225
+ # 学習ステップ数を計算する
226
+ if args.max_train_epochs is not None:
227
+ args.max_train_steps = args.max_train_epochs * math.ceil(
228
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
229
+ )
230
+ accelerator.print(
231
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
232
+ )
233
+
234
+ # データセット側にも学習ステップを送信
235
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
236
+
237
+ # lr schedulerを用意する
238
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
239
+
240
+ # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
241
+ if args.full_fp16:
242
+ assert (
243
+ args.mixed_precision == "fp16"
244
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
245
+ accelerator.print("enable full fp16 training.")
246
+ unet.to(weight_dtype)
247
+ network.to(weight_dtype)
248
+ elif args.full_bf16:
249
+ assert (
250
+ args.mixed_precision == "bf16"
251
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
252
+ accelerator.print("enable full bf16 training.")
253
+ unet.to(weight_dtype)
254
+ network.to(weight_dtype)
255
+
256
+ # accelerator apparently takes care of the rest for us here
257
+ unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
258
+ unet, network, optimizer, train_dataloader, lr_scheduler
259
+ )
260
+ network: control_net_lllite.ControlNetLLLite
261
+
262
+ if args.gradient_checkpointing:
263
+ unet.train() # according to the TI example in Diffusers, train() is required -> since this was changed to the original U-Net, it could probably be removed now
264
+ else:
265
+ unet.eval()
266
+
267
+ network.prepare_grad_etc()
268
+
269
+ # TextEncoderの出力をキャッシュするときにはCPUへ移動する
270
+ if args.cache_text_encoder_outputs:
271
+ # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
272
+ text_encoder1.to("cpu", dtype=torch.float32)
273
+ text_encoder2.to("cpu", dtype=torch.float32)
274
+ clean_memory_on_device(accelerator.device)
275
+ else:
276
+ # make sure Text Encoders are on GPU
277
+ text_encoder1.to(accelerator.device)
278
+ text_encoder2.to(accelerator.device)
279
+
280
+ if not cache_latents:
281
+ vae.requires_grad_(False)
282
+ vae.eval()
283
+ vae.to(accelerator.device, dtype=vae_dtype)
284
+
285
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
286
+ if args.full_fp16:
287
+ train_util.patch_accelerator_for_fp16_training(accelerator)
288
+
289
+ # resumeする
290
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
291
+
292
+ # epoch数を計算する
293
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
294
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
295
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
296
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
297
+
298
+ # 学習する
299
+ # TODO: find a way to handle total batch size when there are multiple datasets
300
+ accelerator.print("running training / 学習開始")
301
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
302
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
303
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
304
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
305
+ accelerator.print(
306
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
307
+ )
308
+ # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
309
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
310
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
311
+
312
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
313
+ global_step = 0
314
+
315
+ noise_scheduler = DDPMScheduler(
316
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
317
+ )
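+ # note: this beta schedule (scaled_linear, 0.00085 to 0.012 over 1000 timesteps) matches the noise schedule used to train the SD/SDXL base models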
318
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
319
+ if args.zero_terminal_snr:
320
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
321
+
322
+ if accelerator.is_main_process:
323
+ init_kwargs = {}
324
+ if args.log_tracker_config is not None:
325
+ init_kwargs = toml.load(args.log_tracker_config)
326
+ accelerator.init_trackers(
327
+ "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
328
+ )
329
+
330
+ loss_recorder = train_util.LossRecorder()
331
+ del train_dataset_group
332
+
333
+ # function for saving/removing
334
+ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
335
+ os.makedirs(args.output_dir, exist_ok=True)
336
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
337
+
338
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
339
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
340
+ sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
341
+
342
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
343
+ if args.huggingface_repo_id is not None:
344
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
345
+
346
+ def remove_model(old_ckpt_name):
347
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
348
+ if os.path.exists(old_ckpt_file):
349
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
350
+ os.remove(old_ckpt_file)
351
+
352
+ # training loop
353
+ for epoch in range(num_train_epochs):
354
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
355
+ current_epoch.value = epoch + 1
356
+
357
+ network.on_epoch_start() # train()
358
+
359
+ for step, batch in enumerate(train_dataloader):
360
+ current_step.value = global_step
361
+ with accelerator.accumulate(network):
362
+ with torch.no_grad():
363
+ if "latents" in batch and batch["latents"] is not None:
364
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
365
+ else:
366
+ # convert images to latents
367
+ latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
368
+
369
+ # if NaNs are found, show a warning and replace them with zeros
370
+ if torch.any(torch.isnan(latents)):
371
+ accelerator.print("NaN found in latents, replacing with zeros")
372
+ latents = torch.nan_to_num(latents, 0, out=latents)
373
+ latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
374
+
375
+ if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
376
+ input_ids1 = batch["input_ids"]
377
+ input_ids2 = batch["input_ids2"]
378
+ with torch.no_grad():
379
+ # Get the text embedding for conditioning
380
+ input_ids1 = input_ids1.to(accelerator.device)
381
+ input_ids2 = input_ids2.to(accelerator.device)
382
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
383
+ args.max_token_length,
384
+ input_ids1,
385
+ input_ids2,
386
+ tokenizer1,
387
+ tokenizer2,
388
+ text_encoder1,
389
+ text_encoder2,
390
+ None if not args.full_fp16 else weight_dtype,
391
+ )
392
+ else:
393
+ encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
394
+ encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
395
+ pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
396
+
397
+ # get size embeddings
398
+ orig_size = batch["original_sizes_hw"]
399
+ crop_size = batch["crop_top_lefts"]
400
+ target_size = batch["target_sizes_hw"]
401
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
402
+
403
+ # concat embeddings
404
+ vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
405
+ text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
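+ # note: for SDXL the hidden states of the two text encoders are concatenated along the feature axis to form
+ # the cross-attention context, while pool2 plus the size embeddings form the additional vector conditioning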
406
+
407
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
408
+ # with noise offset and/or multires noise if specified
409
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
410
+
411
+ noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
412
+
413
+ controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
414
+
415
+ with accelerator.autocast():
416
+ # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
417
+ # 内部でcond_embに変換される / it will be converted to cond_emb inside
418
+ network.set_cond_image(controlnet_image)
419
+
420
+ # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values
421
+ noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
422
+
423
+ if args.v_parameterization:
424
+ # v-parameterization training
425
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
426
+ else:
427
+ target = noise
428
+
429
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
430
+ loss = loss.mean([1, 2, 3])
431
+
432
+ loss_weights = batch["loss_weights"]  # per-sample weights
433
+ loss = loss * loss_weights
434
+
435
+ if args.min_snr_gamma:
436
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
437
+ if args.scale_v_pred_loss_like_noise_pred:
438
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
439
+ if args.v_pred_like_loss:
440
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
441
+ if args.debiased_estimation_loss:
442
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
443
+
444
+ loss = loss.mean()  # already averaged, so no need to divide by batch_size
445
+
446
+ accelerator.backward(loss)
447
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
448
+ params_to_clip = network.get_trainable_params()
449
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
450
+
451
+ optimizer.step()
452
+ lr_scheduler.step()
453
+ optimizer.zero_grad(set_to_none=True)
454
+
455
+ # Checks if the accelerator has performed an optimization step behind the scenes
456
+ if accelerator.sync_gradients:
457
+ progress_bar.update(1)
458
+ global_step += 1
459
+
460
+ # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
461
+
462
+ # save the model every specified number of steps
463
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
464
+ accelerator.wait_for_everyone()
465
+ if accelerator.is_main_process:
466
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
467
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
468
+
469
+ if args.save_state:
470
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
471
+
472
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
473
+ if remove_step_no is not None:
474
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
475
+ remove_model(remove_ckpt_name)
476
+
477
+ current_loss = loss.detach().item()
478
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
479
+ avr_loss: float = loss_recorder.moving_average
480
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
481
+ progress_bar.set_postfix(**logs)
482
+
483
+ if args.logging_dir is not None:
484
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
485
+ accelerator.log(logs, step=global_step)
486
+
487
+ if global_step >= args.max_train_steps:
488
+ break
489
+
490
+ if args.logging_dir is not None:
491
+ logs = {"loss/epoch": loss_recorder.moving_average}
492
+ accelerator.log(logs, step=epoch + 1)
493
+
494
+ accelerator.wait_for_everyone()
495
+
496
+ # save the model every specified number of epochs
497
+ if args.save_every_n_epochs is not None:
498
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
499
+ if is_main_process and saving:
500
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
501
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
502
+
503
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
504
+ if remove_epoch_no is not None:
505
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
506
+ remove_model(remove_ckpt_name)
507
+
508
+ if args.save_state:
509
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
510
+
511
+ # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
512
+
513
+ # end of epoch
514
+
515
+ if is_main_process:
516
+ network = accelerator.unwrap_model(network)
517
+
518
+ accelerator.end_training()
519
+
520
+ if is_main_process and args.save_state:
521
+ train_util.save_state_on_train_end(args, accelerator)
522
+
523
+ if is_main_process:
524
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
525
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
526
+
527
+ logger.info("model saved.")
528
+
529
+
530
+ def setup_parser() -> argparse.ArgumentParser:
531
+ parser = argparse.ArgumentParser()
532
+
533
+ add_logging_arguments(parser)
534
+ train_util.add_sd_models_arguments(parser)
535
+ train_util.add_dataset_arguments(parser, False, True, True)
536
+ train_util.add_training_arguments(parser, False)
537
+ deepspeed_utils.add_deepspeed_arguments(parser)
538
+ train_util.add_optimizer_arguments(parser)
539
+ config_util.add_config_arguments(parser)
540
+ custom_train_functions.add_custom_train_arguments(parser)
541
+ sdxl_train_util.add_sdxl_training_arguments(parser)
542
+
543
+ parser.add_argument(
544
+ "--save_model_as",
545
+ type=str,
546
+ default="safetensors",
547
+ choices=[None, "ckpt", "pt", "safetensors"],
548
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
549
+ )
550
+ parser.add_argument(
551
+ "--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
552
+ )
553
+ parser.add_argument(
554
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
555
+ )
556
+ parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
557
+ parser.add_argument(
558
+ "--network_dropout",
559
+ type=float,
560
+ default=None,
561
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
562
+ )
563
+ parser.add_argument(
564
+ "--conditioning_data_dir",
565
+ type=str,
566
+ default=None,
567
+ help="conditioning data directory / 条件付けデータのディレクトリ",
568
+ )
569
+ parser.add_argument(
570
+ "--no_half_vae",
571
+ action="store_true",
572
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
573
+ )
574
+ return parser
575
+
576
+
577
+ if __name__ == "__main__":
578
+ # sdxl_original_unet.USE_REENTRANT = False
579
+
580
+ parser = setup_parser()
581
+
582
+ args = parser.parse_args()
583
+ train_util.verify_command_line_training_args(args)
584
+ args = train_util.read_config_from_file(args, parser)
585
+
586
+ train(args)
sdxl_train_network.py ADDED
@@ -0,0 +1,184 @@
1
+ import argparse
2
+
3
+ import torch
4
+ from library.device_utils import init_ipex, clean_memory_on_device
5
+ init_ipex()
6
+
7
+ from library import sdxl_model_util, sdxl_train_util, train_util
8
+ import train_network
9
+ from library.utils import setup_logging
10
+ setup_logging()
11
+ import logging
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class SdxlNetworkTrainer(train_network.NetworkTrainer):
15
+ def __init__(self):
16
+ super().__init__()
17
+ self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
18
+ self.is_sdxl = True
19
+
20
+ def assert_extra_args(self, args, train_dataset_group):
21
+ sdxl_train_util.verify_sdxl_training_args(args)
22
+
23
+ if args.cache_text_encoder_outputs:
24
+ assert (
25
+ train_dataset_group.is_text_encoder_output_cacheable()
26
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
27
+
28
+ assert (
29
+ args.network_train_unet_only or not args.cache_text_encoder_outputs
30
+ ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
31
+
32
+ train_dataset_group.verify_bucket_reso_steps(32)
33
+
34
+ def load_target_model(self, args, weight_dtype, accelerator):
35
+ (
36
+ load_stable_diffusion_format,
37
+ text_encoder1,
38
+ text_encoder2,
39
+ vae,
40
+ unet,
41
+ logit_scale,
42
+ ckpt_info,
43
+ ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
44
+
45
+ self.load_stable_diffusion_format = load_stable_diffusion_format
46
+ self.logit_scale = logit_scale
47
+ self.ckpt_info = ckpt_info
48
+
49
+ return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
50
+
51
+ def load_tokenizer(self, args):
52
+ tokenizer = sdxl_train_util.load_tokenizers(args)
53
+ return tokenizer
54
+
55
+ def is_text_encoder_outputs_cached(self, args):
56
+ return args.cache_text_encoder_outputs
57
+
58
+ def cache_text_encoder_outputs_if_needed(
59
+ self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
60
+ ):
61
+ if args.cache_text_encoder_outputs:
62
+ if not args.lowram:
63
+ # reduce memory consumption
64
+ logger.info("move vae and unet to cpu to save memory")
65
+ org_vae_device = vae.device
66
+ org_unet_device = unet.device
67
+ vae.to("cpu")
68
+ unet.to("cpu")
69
+ clean_memory_on_device(accelerator.device)
70
+
71
+ # When the TE is not trained, it is not prepared by accelerator, so we need to use an explicit autocast
72
+ with accelerator.autocast():
73
+ dataset.cache_text_encoder_outputs(
74
+ tokenizers,
75
+ text_encoders,
76
+ accelerator.device,
77
+ weight_dtype,
78
+ args.cache_text_encoder_outputs_to_disk,
79
+ accelerator.is_main_process,
80
+ )
81
+
82
+ text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
83
+ text_encoders[1].to("cpu", dtype=torch.float32)
84
+ clean_memory_on_device(accelerator.device)
85
+
86
+ if not args.lowram:
87
+ logger.info("move vae and unet back to original device")
88
+ vae.to(org_vae_device)
89
+ unet.to(org_unet_device)
90
+ else:
91
+ # outputs are fetched from the Text Encoders on every step, so keep them on the GPU
92
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
93
+ text_encoders[1].to(accelerator.device, dtype=weight_dtype)
94
+
95
+ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
96
+ if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
97
+ input_ids1 = batch["input_ids"]
98
+ input_ids2 = batch["input_ids2"]
99
+ with torch.enable_grad():
100
+ # Get the text embedding for conditioning
101
+ # TODO support weighted captions
102
+ # if args.weighted_captions:
103
+ # encoder_hidden_states = get_weighted_text_embeddings(
104
+ # tokenizer,
105
+ # text_encoder,
106
+ # batch["captions"],
107
+ # accelerator.device,
108
+ # args.max_token_length // 75 if args.max_token_length else 1,
109
+ # clip_skip=args.clip_skip,
110
+ # )
111
+ # else:
112
+ input_ids1 = input_ids1.to(accelerator.device)
113
+ input_ids2 = input_ids2.to(accelerator.device)
114
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
115
+ args.max_token_length,
116
+ input_ids1,
117
+ input_ids2,
118
+ tokenizers[0],
119
+ tokenizers[1],
120
+ text_encoders[0],
121
+ text_encoders[1],
122
+ None if not args.full_fp16 else weight_dtype,
123
+ accelerator=accelerator,
124
+ )
125
+ else:
126
+ encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
127
+ encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
128
+ pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
129
+
130
+ # # verify that the text encoder outputs are correct
131
+ # ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
132
+ # args.max_token_length,
133
+ # batch["input_ids"].to(text_encoders[0].device),
134
+ # batch["input_ids2"].to(text_encoders[0].device),
135
+ # tokenizers[0],
136
+ # tokenizers[1],
137
+ # text_encoders[0],
138
+ # text_encoders[1],
139
+ # None if not args.full_fp16 else weight_dtype,
140
+ # )
141
+ # b_size = encoder_hidden_states1.shape[0]
142
+ # assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
143
+ # assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
144
+ # assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
145
+ # logger.info("text encoder outputs verified")
146
+
147
+ return encoder_hidden_states1, encoder_hidden_states2, pool2
148
+
149
+ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
150
+ noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
151
+
152
+ # get size embeddings
153
+ orig_size = batch["original_sizes_hw"]
154
+ crop_size = batch["crop_top_lefts"]
155
+ target_size = batch["target_sizes_hw"]
156
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
157
+
158
+ # concat embeddings
159
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
160
+ vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
161
+ text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
162
+
163
+ noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
164
+ return noise_pred
165
+
166
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
167
+ sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
168
+
169
+
170
+ def setup_parser() -> argparse.ArgumentParser:
171
+ parser = train_network.setup_parser()
172
+ sdxl_train_util.add_sdxl_training_arguments(parser)
173
+ return parser
174
+
175
+
176
+ if __name__ == "__main__":
177
+ parser = setup_parser()
178
+
179
+ args = parser.parse_args()
180
+ train_util.verify_command_line_training_args(args)
181
+ args = train_util.read_config_from_file(args, parser)
182
+
183
+ trainer = SdxlNetworkTrainer()
184
+ trainer.train(args)
sdxl_train_textual_inversion.py ADDED
@@ -0,0 +1,138 @@
1
+ import argparse
2
+ import os
3
+
4
+ import regex
5
+
6
+ import torch
7
+ from library.device_utils import init_ipex
8
+ init_ipex()
9
+
10
+ from library import sdxl_model_util, sdxl_train_util, train_util
11
+
12
+ import train_textual_inversion
13
+
14
+
15
+ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
19
+ self.is_sdxl = True
20
+
21
+ def assert_extra_args(self, args, train_dataset_group):
22
+ super().assert_extra_args(args, train_dataset_group)
23
+ sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
24
+
25
+ train_dataset_group.verify_bucket_reso_steps(32)
26
+
27
+ def load_target_model(self, args, weight_dtype, accelerator):
28
+ (
29
+ load_stable_diffusion_format,
30
+ text_encoder1,
31
+ text_encoder2,
32
+ vae,
33
+ unet,
34
+ logit_scale,
35
+ ckpt_info,
36
+ ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
37
+
38
+ self.load_stable_diffusion_format = load_stable_diffusion_format
39
+ self.logit_scale = logit_scale
40
+ self.ckpt_info = ckpt_info
41
+
42
+ return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
43
+
44
+ def load_tokenizer(self, args):
45
+ tokenizer = sdxl_train_util.load_tokenizers(args)
46
+ return tokenizer
47
+
48
+ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
49
+ input_ids1 = batch["input_ids"]
50
+ input_ids2 = batch["input_ids2"]
51
+ with torch.enable_grad():
52
+ input_ids1 = input_ids1.to(accelerator.device)
53
+ input_ids2 = input_ids2.to(accelerator.device)
54
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
55
+ args.max_token_length,
56
+ input_ids1,
57
+ input_ids2,
58
+ tokenizers[0],
59
+ tokenizers[1],
60
+ text_encoders[0],
61
+ text_encoders[1],
62
+ None if not args.full_fp16 else weight_dtype,
63
+ accelerator=accelerator,
64
+ )
65
+ return encoder_hidden_states1, encoder_hidden_states2, pool2
66
+
67
+ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
68
+ noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
69
+
70
+ # get size embeddings
71
+ orig_size = batch["original_sizes_hw"]
72
+ crop_size = batch["crop_top_lefts"]
73
+ target_size = batch["target_sizes_hw"]
74
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
75
+
76
+ # concat embeddings
77
+ encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
78
+ vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
79
+ text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
80
+
81
+ noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
82
+ return noise_pred
83
+
84
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
85
+ sdxl_train_util.sample_images(
86
+ accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
87
+ )
88
+
89
+ def save_weights(self, file, updated_embs, save_dtype, metadata):
90
+ state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
91
+
92
+ if save_dtype is not None:
93
+ for key in list(state_dict.keys()):
94
+ v = state_dict[key]
95
+ v = v.detach().clone().to("cpu").to(save_dtype)
96
+ state_dict[key] = v
97
+
98
+ if os.path.splitext(file)[1] == ".safetensors":
99
+ from safetensors.torch import save_file
100
+
101
+ save_file(state_dict, file, metadata)
102
+ else:
103
+ torch.save(state_dict, file)
104
+
105
+ def load_weights(self, file):
106
+ if os.path.splitext(file)[1] == ".safetensors":
107
+ from safetensors.torch import load_file
108
+
109
+ data = load_file(file)
110
+ else:
111
+ data = torch.load(file, map_location="cpu")
112
+
113
+ emb_l = data.get("clip_l", None) # ViT-L text encoder 1
114
+ emb_g = data.get("clip_g", None) # BiG-G text encoder 2
115
+
116
+ assert (
117
+ emb_l is not None or emb_g is not None
118
+ ), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
119
+
120
+ return [emb_l, emb_g]
121
+
122
+
123
+ def setup_parser() -> argparse.ArgumentParser:
124
+ parser = train_textual_inversion.setup_parser()
125
+ # don't call sdxl_train_util.add_sdxl_training_arguments(parser) here, because it only adds text encoder caching arguments
126
+ # sdxl_train_util.add_sdxl_training_arguments(parser)
127
+ return parser
128
+
129
+
130
+ if __name__ == "__main__":
131
+ parser = setup_parser()
132
+
133
+ args = parser.parse_args()
134
+ train_util.verify_command_line_training_args(args)
135
+ args = train_util.read_config_from_file(args, parser)
136
+
137
+ trainer = SdxlTextualInversionTrainer()
138
+ trainer.train(args)
setup.py ADDED
@@ -0,0 +1,3 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(name = "library", packages = find_packages())
train_controlnet.py ADDED
@@ -0,0 +1,648 @@
1
+ import argparse
2
+ import json
3
+ import math
4
+ import os
5
+ import random
6
+ import time
7
+ from multiprocessing import Value
8
+
9
+ # from omegaconf import OmegaConf
10
+ import toml
11
+
12
+ from tqdm import tqdm
13
+
14
+ import torch
15
+ from library import deepspeed_utils
16
+ from library.device_utils import init_ipex, clean_memory_on_device
17
+
18
+ init_ipex()
19
+
20
+ from torch.nn.parallel import DistributedDataParallel as DDP
21
+ from accelerate.utils import set_seed
22
+ from diffusers import DDPMScheduler, ControlNetModel
23
+ from safetensors.torch import load_file
24
+
25
+ import library.model_util as model_util
26
+ import library.train_util as train_util
27
+ import library.config_util as config_util
28
+ from library.config_util import (
29
+ ConfigSanitizer,
30
+ BlueprintGenerator,
31
+ )
32
+ import library.huggingface_util as huggingface_util
33
+ import library.custom_train_functions as custom_train_functions
34
+ from library.custom_train_functions import (
35
+ apply_snr_weight,
36
+ pyramid_noise_like,
37
+ apply_noise_offset,
38
+ )
39
+ from library.utils import setup_logging, add_logging_arguments
40
+
41
+ setup_logging()
42
+ import logging
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ # TODO unify this with the other scripts
48
+ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
49
+ logs = {
50
+ "loss/current": current_loss,
51
+ "loss/average": avr_loss,
52
+ "lr": lr_scheduler.get_last_lr()[0],
53
+ }
54
+
55
+ if args.optimizer_type.lower().startswith("DAdapt".lower()):
56
+ logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
57
+
58
+ return logs
59
+
60
+
61
+ def train(args):
62
+ # session_id = random.randint(0, 2**32)
63
+ # training_started_at = time.time()
64
+ train_util.verify_training_args(args)
65
+ train_util.prepare_dataset_args(args, True)
66
+ setup_logging(args, reset=True)
67
+
68
+ cache_latents = args.cache_latents
69
+ use_user_config = args.dataset_config is not None
70
+
71
+ if args.seed is None:
72
+ args.seed = random.randint(0, 2**32)
73
+ set_seed(args.seed)
74
+
75
+ tokenizer = train_util.load_tokenizer(args)
76
+
77
+ # prepare the dataset
78
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
79
+ if use_user_config:
80
+ logger.info(f"Load dataset config from {args.dataset_config}")
81
+ user_config = config_util.load_user_config(args.dataset_config)
82
+ ignored = ["train_data_dir", "conditioning_data_dir"]
83
+ if any(getattr(args, attr) is not None for attr in ignored):
84
+ logger.warning(
85
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
86
+ ", ".join(ignored)
87
+ )
88
+ )
89
+ else:
90
+ user_config = {
91
+ "datasets": [
92
+ {
93
+ "subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
94
+ args.train_data_dir,
95
+ args.conditioning_data_dir,
96
+ args.caption_extension,
97
+ )
98
+ }
99
+ ]
100
+ }
101
+
102
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
103
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
104
+
105
+ current_epoch = Value("i", 0)
106
+ current_step = Value("i", 0)
107
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
108
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
109
+
110
+ train_dataset_group.verify_bucket_reso_steps(64)
111
+
112
+ if args.debug_dataset:
113
+ train_util.debug_dataset(train_dataset_group)
114
+ return
115
+ if len(train_dataset_group) == 0:
116
+ logger.error(
117
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
118
+ )
119
+ return
120
+
121
+ if cache_latents:
122
+ assert (
123
+ train_dataset_group.is_latent_cacheable()
124
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
125
+
126
+ # prepare the accelerator
127
+ logger.info("prepare accelerator")
128
+ accelerator = train_util.prepare_accelerator(args)
129
+ is_main_process = accelerator.is_main_process
130
+
131
+ # prepare dtypes corresponding to the mixed precision setting and cast as appropriate
132
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
133
+
134
+ # load the models
135
+ text_encoder, vae, unet, _ = train_util.load_target_model(
136
+ args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
137
+ )
138
+
139
+ # prepare the config data used by the Diffusers ControlNet
140
+ if args.v2:
141
+ unet.config = {
142
+ "act_fn": "silu",
143
+ "attention_head_dim": [5, 10, 20, 20],
144
+ "block_out_channels": [320, 640, 1280, 1280],
145
+ "center_input_sample": False,
146
+ "cross_attention_dim": 1024,
147
+ "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
148
+ "downsample_padding": 1,
149
+ "dual_cross_attention": False,
150
+ "flip_sin_to_cos": True,
151
+ "freq_shift": 0,
152
+ "in_channels": 4,
153
+ "layers_per_block": 2,
154
+ "mid_block_scale_factor": 1,
155
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
156
+ "norm_eps": 1e-05,
157
+ "norm_num_groups": 32,
158
+ "num_attention_heads": [5, 10, 20, 20],
159
+ "num_class_embeds": None,
160
+ "only_cross_attention": False,
161
+ "out_channels": 4,
162
+ "sample_size": 96,
163
+ "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
164
+ "use_linear_projection": True,
165
+ "upcast_attention": True,
166
+ "only_cross_attention": False,
167
+ "downsample_padding": 1,
168
+ "use_linear_projection": True,
169
+ "class_embed_type": None,
170
+ "num_class_embeds": None,
171
+ "resnet_time_scale_shift": "default",
172
+ "projection_class_embeddings_input_dim": None,
173
+ }
174
+ else:
175
+ unet.config = {
176
+ "act_fn": "silu",
177
+ "attention_head_dim": 8,
178
+ "block_out_channels": [320, 640, 1280, 1280],
179
+ "center_input_sample": False,
180
+ "cross_attention_dim": 768,
181
+ "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
182
+ "downsample_padding": 1,
183
+ "flip_sin_to_cos": True,
184
+ "freq_shift": 0,
185
+ "in_channels": 4,
186
+ "layers_per_block": 2,
187
+ "mid_block_scale_factor": 1,
188
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
189
+ "norm_eps": 1e-05,
190
+ "norm_num_groups": 32,
191
+ "num_attention_heads": 8,
192
+ "out_channels": 4,
193
+ "sample_size": 64,
194
+ "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
195
+ "only_cross_attention": False,
196
+ "downsample_padding": 1,
197
+ "use_linear_projection": False,
198
+ "class_embed_type": None,
199
+ "num_class_embeds": None,
200
+ "upcast_attention": False,
201
+ "resnet_time_scale_shift": "default",
202
+ "projection_class_embeddings_input_dim": None,
203
+ }
204
+ # unet.config = OmegaConf.create(unet.config)
205
+
206
+ # make unet.config iterable and accessible by attribute
207
+ class CustomConfig:
208
+ def __init__(self, **kwargs):
209
+ self.__dict__.update(kwargs)
210
+
211
+ def __getattr__(self, name):
212
+ if name in self.__dict__:
213
+ return self.__dict__[name]
214
+ else:
215
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
216
+
217
+ def __contains__(self, name):
218
+ return name in self.__dict__
219
+
220
+ unet.config = CustomConfig(**unet.config)
221
+
222
+ controlnet = ControlNetModel.from_unet(unet)
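+ # note: ControlNetModel.from_unet builds a ControlNet matching the U-Net config and, by default in Diffusers, copies the matching U-Net encoder weights into it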
223
+
224
+ if args.controlnet_model_name_or_path:
225
+ filename = args.controlnet_model_name_or_path
226
+ if os.path.isfile(filename):
227
+ if os.path.splitext(filename)[1] == ".safetensors":
228
+ state_dict = load_file(filename)
229
+ else:
230
+ state_dict = torch.load(filename)
231
+ state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
232
+ controlnet.load_state_dict(state_dict)
233
+ elif os.path.isdir(filename):
234
+ controlnet = ControlNetModel.from_pretrained(filename)
235
+
236
+ # incorporate xformers or memory efficient attention into the model
237
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
238
+
239
+ # prepare for training
240
+ if cache_latents:
241
+ vae.to(accelerator.device, dtype=weight_dtype)
242
+ vae.requires_grad_(False)
243
+ vae.eval()
244
+ with torch.no_grad():
245
+ train_dataset_group.cache_latents(
246
+ vae,
247
+ args.vae_batch_size,
248
+ args.cache_latents_to_disk,
249
+ accelerator.is_main_process,
250
+ )
251
+ vae.to("cpu")
252
+ clean_memory_on_device(accelerator.device)
253
+
254
+ accelerator.wait_for_everyone()
255
+
256
+ if args.gradient_checkpointing:
257
+ controlnet.enable_gradient_checkpointing()
258
+
259
+ # prepare the classes needed for training
260
+ accelerator.print("prepare optimizer, data loader etc.")
261
+
262
+ trainable_params = list(controlnet.parameters())
263
+
264
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
265
+
266
+ # prepare the dataloader
267
+ # note: persistent_workers cannot be used when the number of DataLoader workers is 0
268
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
269
+
270
+ train_dataloader = torch.utils.data.DataLoader(
271
+ train_dataset_group,
272
+ batch_size=1,
273
+ shuffle=True,
274
+ collate_fn=collator,
275
+ num_workers=n_workers,
276
+ persistent_workers=args.persistent_data_loader_workers,
277
+ )
278
+
279
+ # calculate the number of training steps
280
+ if args.max_train_epochs is not None:
281
+ args.max_train_steps = args.max_train_epochs * math.ceil(
282
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
283
+ )
284
+ accelerator.print(
285
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
286
+ )
287
+
288
+ # also pass the number of training steps to the dataset
289
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
290
+
291
+ # prepare the lr scheduler
292
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
293
+
294
+ # experimental feature: perform fp16 training including gradients; cast the entire model to fp16
295
+ if args.full_fp16:
296
+ assert (
297
+ args.mixed_precision == "fp16"
298
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
299
+ accelerator.print("enable full fp16 training.")
300
+ controlnet.to(weight_dtype)
301
+
302
+ # accelerator apparently handles the details for us here
303
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
304
+ controlnet, optimizer, train_dataloader, lr_scheduler
305
+ )
306
+
307
+ unet.requires_grad_(False)
308
+ text_encoder.requires_grad_(False)
309
+ unet.to(accelerator.device)
310
+ text_encoder.to(accelerator.device)
311
+
312
+ # transform DDP after prepare
313
+ controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
314
+
315
+ controlnet.train()
316
+
317
+ if not cache_latents:
318
+ vae.requires_grad_(False)
319
+ vae.eval()
320
+ vae.to(accelerator.device, dtype=weight_dtype)
321
+
322
+ # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
323
+ if args.full_fp16:
324
+ train_util.patch_accelerator_for_fp16_training(accelerator)
325
+
326
+ # resume training if specified
327
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
328
+
329
+ # calculate the number of epochs
330
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
331
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
332
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
333
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
334
+
335
+ # start training
336
+ # TODO: find a way to handle total batch size when there are multiple datasets
337
+ accelerator.print("running training / 学習開始")
338
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
339
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
340
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
341
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
342
+ accelerator.print(
343
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
344
+ )
345
+ # logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
346
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
347
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
348
+
349
+ progress_bar = tqdm(
350
+ range(args.max_train_steps),
351
+ smoothing=0,
352
+ disable=not accelerator.is_local_main_process,
353
+ desc="steps",
354
+ )
355
+ global_step = 0
356
+
357
+ noise_scheduler = DDPMScheduler(
358
+ beta_start=0.00085,
359
+ beta_end=0.012,
360
+ beta_schedule="scaled_linear",
361
+ num_train_timesteps=1000,
362
+ clip_sample=False,
363
+ )
364
+ if accelerator.is_main_process:
365
+ init_kwargs = {}
366
+ if args.wandb_run_name:
367
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
368
+ if args.log_tracker_config is not None:
369
+ init_kwargs = toml.load(args.log_tracker_config)
370
+ accelerator.init_trackers(
371
+ "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
372
+ config=train_util.get_sanitized_config_or_none(args),
373
+ init_kwargs=init_kwargs,
374
+ )
375
+
376
+ loss_recorder = train_util.LossRecorder()
377
+ del train_dataset_group
378
+
379
+ # function for saving/removing
380
+ def save_model(ckpt_name, model, force_sync_upload=False):
381
+ os.makedirs(args.output_dir, exist_ok=True)
382
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
383
+
384
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
385
+
386
+ state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
387
+
388
+ if save_dtype is not None:
389
+ for key in list(state_dict.keys()):
390
+ v = state_dict[key]
391
+ v = v.detach().clone().to("cpu").to(save_dtype)
392
+ state_dict[key] = v
393
+
394
+ if os.path.splitext(ckpt_file)[1] == ".safetensors":
395
+ from safetensors.torch import save_file
396
+
397
+ save_file(state_dict, ckpt_file)
398
+ else:
399
+ torch.save(state_dict, ckpt_file)
400
+
401
+ if args.huggingface_repo_id is not None:
402
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
403
+
404
+ def remove_model(old_ckpt_name):
405
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
406
+ if os.path.exists(old_ckpt_file):
407
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
408
+ os.remove(old_ckpt_file)
409
+
410
+ # For --sample_at_first
411
+ train_util.sample_images(
412
+ accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
413
+ )
414
+
415
+ # training loop
416
+ for epoch in range(num_train_epochs):
417
+ if is_main_process:
418
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
419
+ current_epoch.value = epoch + 1
420
+
421
+ for step, batch in enumerate(train_dataloader):
422
+ current_step.value = global_step
423
+ with accelerator.accumulate(controlnet):
424
+ with torch.no_grad():
425
+ if "latents" in batch and batch["latents"] is not None:
426
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
427
+ else:
428
+ # convert images to latents
429
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
430
+ latents = latents * 0.18215  # scale by the SD VAE scaling factor
431
+ b_size = latents.shape[0]
432
+
433
+ input_ids = batch["input_ids"].to(accelerator.device)
434
+ encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
435
+
436
+ # Sample noise that we'll add to the latents
437
+ noise = torch.randn_like(latents, device=latents.device)
438
+ if args.noise_offset:
439
+ noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
440
+ elif args.multires_noise_iterations:
441
+ noise = pyramid_noise_like(
442
+ noise,
443
+ latents.device,
444
+ args.multires_noise_iterations,
445
+ args.multires_noise_discount,
446
+ )
447
+
448
+ # Sample a random timestep for each image
449
+ timesteps, huber_c = train_util.get_timesteps_and_huber_c(
450
+ args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
451
+ )
452
+
453
+ # Add noise to the latents according to the noise magnitude at each timestep
454
+ # (this is the forward diffusion process)
455
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
456
+
457
+ controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
458
+
459
+ with accelerator.autocast():
460
+ down_block_res_samples, mid_block_res_sample = controlnet(
461
+ noisy_latents,
462
+ timesteps,
463
+ encoder_hidden_states=encoder_hidden_states,
464
+ controlnet_cond=controlnet_image,
465
+ return_dict=False,
466
+ )
467
+
468
+ # Predict the noise residual
469
+ noise_pred = unet(
470
+ noisy_latents,
471
+ timesteps,
472
+ encoder_hidden_states,
473
+ down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
474
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
475
+ ).sample
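+ # note: the ControlNet down/mid residuals are added to the corresponding U-Net skip connections and mid-block output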
476
+
477
+ if args.v_parameterization:
478
+ # v-parameterization training
479
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
480
+ else:
481
+ target = noise
482
+
483
+ loss = train_util.conditional_loss(
484
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
485
+ )
486
+ loss = loss.mean([1, 2, 3])
487
+
488
+ loss_weights = batch["loss_weights"]  # per-sample weights
489
+ loss = loss * loss_weights
490
+
491
+ if args.min_snr_gamma:
492
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
493
+
494
+ loss = loss.mean()  # already averaged, so no need to divide by batch_size
495
+
496
+ accelerator.backward(loss)
497
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
498
+ params_to_clip = controlnet.parameters()
499
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
500
+
501
+ optimizer.step()
502
+ lr_scheduler.step()
503
+ optimizer.zero_grad(set_to_none=True)
504
+
505
+ # Checks if the accelerator has performed an optimization step behind the scenes
506
+ if accelerator.sync_gradients:
507
+ progress_bar.update(1)
508
+ global_step += 1
509
+
510
+ train_util.sample_images(
511
+ accelerator,
512
+ args,
513
+ None,
514
+ global_step,
515
+ accelerator.device,
516
+ vae,
517
+ tokenizer,
518
+ text_encoder,
519
+ unet,
520
+ controlnet=controlnet,
521
+ )
522
+
523
+ # save the model every specified number of steps
524
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
525
+ accelerator.wait_for_everyone()
526
+ if accelerator.is_main_process:
527
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
528
+ save_model(
529
+ ckpt_name,
530
+ accelerator.unwrap_model(controlnet),
531
+ )
532
+
533
+ if args.save_state:
534
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
535
+
536
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
537
+ if remove_step_no is not None:
538
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
539
+ remove_model(remove_ckpt_name)
540
+
541
+ current_loss = loss.detach().item()
542
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
543
+ avr_loss: float = loss_recorder.moving_average
544
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
545
+ progress_bar.set_postfix(**logs)
546
+
547
+ if args.logging_dir is not None:
548
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
549
+ accelerator.log(logs, step=global_step)
550
+
551
+ if global_step >= args.max_train_steps:
552
+ break
553
+
554
+ if args.logging_dir is not None:
555
+ logs = {"loss/epoch": loss_recorder.moving_average}
556
+ accelerator.log(logs, step=epoch + 1)
557
+
558
+ accelerator.wait_for_everyone()
559
+
560
+ # save the model every specified number of epochs
561
+ if args.save_every_n_epochs is not None:
562
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
563
+ if is_main_process and saving:
564
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
565
+ save_model(ckpt_name, accelerator.unwrap_model(controlnet))
566
+
567
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
568
+ if remove_epoch_no is not None:
569
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
570
+ remove_model(remove_ckpt_name)
571
+
572
+ if args.save_state:
573
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
574
+
575
+ train_util.sample_images(
576
+ accelerator,
577
+ args,
578
+ epoch + 1,
579
+ global_step,
580
+ accelerator.device,
581
+ vae,
582
+ tokenizer,
583
+ text_encoder,
584
+ unet,
585
+ controlnet=controlnet,
586
+ )
587
+
588
+ # end of epoch
589
+ if is_main_process:
590
+ controlnet = accelerator.unwrap_model(controlnet)
591
+
592
+ accelerator.end_training()
593
+
594
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
595
+ train_util.save_state_on_train_end(args, accelerator)
596
+
597
+ # del accelerator  # would free memory here, but kept because accelerator is still used for printing below
598
+
599
+ if is_main_process:
600
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
601
+ save_model(ckpt_name, controlnet, force_sync_upload=True)
602
+
603
+ logger.info("model saved.")
604
+
605
+
606
+ def setup_parser() -> argparse.ArgumentParser:
607
+ parser = argparse.ArgumentParser()
608
+
609
+ add_logging_arguments(parser)
610
+ train_util.add_sd_models_arguments(parser)
611
+ train_util.add_dataset_arguments(parser, False, True, True)
612
+ train_util.add_training_arguments(parser, False)
613
+ deepspeed_utils.add_deepspeed_arguments(parser)
614
+ train_util.add_optimizer_arguments(parser)
615
+ config_util.add_config_arguments(parser)
616
+ custom_train_functions.add_custom_train_arguments(parser)
617
+
618
+ parser.add_argument(
619
+ "--save_model_as",
620
+ type=str,
621
+ default="safetensors",
622
+ choices=[None, "ckpt", "pt", "safetensors"],
623
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
624
+ )
625
+ parser.add_argument(
626
+ "--controlnet_model_name_or_path",
627
+ type=str,
628
+ default=None,
629
+ help="controlnet model name or path / controlnetのモデル名またはパス",
630
+ )
631
+ parser.add_argument(
632
+ "--conditioning_data_dir",
633
+ type=str,
634
+ default=None,
635
+ help="conditioning data directory / 条件付けデータのディレクトリ",
636
+ )
637
+
638
+ return parser
639
+
640
+
641
+ if __name__ == "__main__":
642
+ parser = setup_parser()
643
+
644
+ args = parser.parse_args()
645
+ train_util.verify_command_line_training_args(args)
646
+ args = train_util.read_config_from_file(args, parser)
647
+
648
+ train(args)
train_db.py ADDED
@@ -0,0 +1,531 @@
1
+ # DreamBooth training
2
+ # XXX dropped option: fine_tune
3
+
4
+ import argparse
5
+ import itertools
6
+ import math
7
+ import os
8
+ from multiprocessing import Value
9
+ import toml
10
+
11
+ from tqdm import tqdm
12
+
13
+ import torch
14
+ from library import deepspeed_utils
15
+ from library.device_utils import init_ipex, clean_memory_on_device
16
+
17
+
18
+ init_ipex()
19
+
20
+ from accelerate.utils import set_seed
21
+ from diffusers import DDPMScheduler
22
+
23
+ import library.train_util as train_util
24
+ import library.config_util as config_util
25
+ from library.config_util import (
26
+ ConfigSanitizer,
27
+ BlueprintGenerator,
28
+ )
29
+ import library.custom_train_functions as custom_train_functions
30
+ from library.custom_train_functions import (
31
+ apply_snr_weight,
32
+ get_weighted_text_embeddings,
33
+ prepare_scheduler_for_custom_training,
34
+ pyramid_noise_like,
35
+ apply_noise_offset,
36
+ scale_v_prediction_loss_like_noise_prediction,
37
+ apply_debiased_estimation,
38
+ apply_masked_loss,
39
+ )
40
+ from library.utils import setup_logging, add_logging_arguments
41
+
42
+ setup_logging()
43
+ import logging
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ # perlin_noise,
48
+
49
+
50
+ def train(args):
51
+ train_util.verify_training_args(args)
52
+ train_util.prepare_dataset_args(args, False)
53
+ deepspeed_utils.prepare_deepspeed_args(args)
54
+ setup_logging(args, reset=True)
55
+
56
+ cache_latents = args.cache_latents
57
+
58
+ if args.seed is not None:
59
+ set_seed(args.seed)  # initialize the random number sequence
60
+
61
+ tokenizer = train_util.load_tokenizer(args)
62
+
63
+ # prepare the dataset
64
+ if args.dataset_class is None:
65
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
66
+ if args.dataset_config is not None:
67
+ logger.info(f"Load dataset config from {args.dataset_config}")
68
+ user_config = config_util.load_user_config(args.dataset_config)
69
+ ignored = ["train_data_dir", "reg_data_dir"]
70
+ if any(getattr(args, attr) is not None for attr in ignored):
71
+ logger.warning(
72
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
73
+ ", ".join(ignored)
74
+ )
75
+ )
76
+ else:
77
+ user_config = {
78
+ "datasets": [
79
+ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
80
+ ]
81
+ }
82
+
83
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
84
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
85
+ else:
86
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
87
+
88
+ current_epoch = Value("i", 0)
89
+ current_step = Value("i", 0)
90
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
91
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
92
+
93
+ if args.no_token_padding:
94
+ train_dataset_group.disable_token_padding()
95
+
96
+ train_dataset_group.verify_bucket_reso_steps(64)
97
+
98
+ if args.debug_dataset:
99
+ train_util.debug_dataset(train_dataset_group)
100
+ return
101
+
102
+ if cache_latents:
103
+ assert (
104
+ train_dataset_group.is_latent_cacheable()
105
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
106
+
107
+ # prepare the accelerator
108
+ logger.info("prepare accelerator")
109
+
110
+ if args.gradient_accumulation_steps > 1:
111
+ logger.warning(
112
+ f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
113
+ )
114
+ logger.warning(
115
+ f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
116
+ )
117
+
118
+ accelerator = train_util.prepare_accelerator(args)
119
+
120
+ # prepare dtypes corresponding to the mixed precision setting and cast as appropriate
121
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
122
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
123
+
124
+ # load the models
125
+ text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
126
+
127
+ # verify load/save model formats
128
+ if load_stable_diffusion_format:
129
+ src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
130
+ src_diffusers_model_path = None
131
+ else:
132
+ src_stable_diffusion_ckpt = None
133
+ src_diffusers_model_path = args.pretrained_model_name_or_path
134
+
135
+ if args.save_model_as is None:
136
+ save_stable_diffusion_format = load_stable_diffusion_format
137
+ use_safetensors = args.use_safetensors
138
+ else:
139
+ save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
140
+ use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
141
+
142
+ # モデルに xformers とか memory efficient attention を組み込む / integrate xformers or memory-efficient attention into the model
143
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
144
+
145
+ # 学習を準備する / prepare for training
146
+ if cache_latents:
147
+ vae.to(accelerator.device, dtype=vae_dtype)
148
+ vae.requires_grad_(False)
149
+ vae.eval()
150
+ with torch.no_grad():
151
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
152
+ vae.to("cpu")
153
+ clean_memory_on_device(accelerator.device)
154
+
155
+ accelerator.wait_for_everyone()
156
+
157
+ # 学習を準備する:モデルを適切な状態にする / prepare for training: put the models into the proper state
158
+ train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
159
+ unet.requires_grad_(True) # 念のため追加 / just in case
160
+ text_encoder.requires_grad_(train_text_encoder)
161
+ if not train_text_encoder:
162
+ accelerator.print("Text Encoder is not trained.")
163
+
164
+ if args.gradient_checkpointing:
165
+ unet.enable_gradient_checkpointing()
166
+ text_encoder.gradient_checkpointing_enable()
167
+
168
+ if not cache_latents:
169
+ vae.requires_grad_(False)
170
+ vae.eval()
171
+ vae.to(accelerator.device, dtype=weight_dtype)
172
+
173
+ # 学習に必要なクラスを準備する / prepare the classes required for training
174
+ accelerator.print("prepare optimizer, data loader etc.")
175
+ if train_text_encoder:
176
+ if args.learning_rate_te is None:
177
+ # without wrapping in a list, adamw8bit crashes
178
+ trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
179
+ else:
180
+ trainable_params = [
181
+ {"params": list(unet.parameters()), "lr": args.learning_rate},
182
+ {"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
183
+ ]
184
+ else:
185
+ trainable_params = unet.parameters()
186
+
187
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
188
+
189
+ # dataloaderを準備する / prepare the DataLoader
190
+ # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 / note that persistent_workers cannot be used when the number of DataLoader workers is 0
191
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
192
+ train_dataloader = torch.utils.data.DataLoader(
193
+ train_dataset_group,
194
+ batch_size=1,
195
+ shuffle=True,
196
+ collate_fn=collator,
197
+ num_workers=n_workers,
198
+ persistent_workers=args.persistent_data_loader_workers,
199
+ )
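+ # NOTE: batch_size is fixed at 1 because the dataset group yields already-batched (bucket-grouped) samples;
+ # the effective batch size is governed by train_batch_size handled inside the dataset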
200
+
201
+ # 学習ステップ数を計算する / calculate the number of training steps
202
+ if args.max_train_epochs is not None:
203
+ args.max_train_steps = args.max_train_epochs * math.ceil(
204
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
205
+ )
206
+ accelerator.print(
207
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
208
+ )
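+ # e.g. 1,000 batches per epoch on 2 processes with gradient_accumulation_steps=4 gives
+ # ceil(1000 / 2 / 4) = 125 optimizer steps per epoch, so max_train_epochs=10 -> max_train_steps=1250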
209
+
210
+ # データセット側にも学習ステップを送信 / also pass the training step count to the dataset
211
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
212
+
213
+ if args.stop_text_encoder_training is None:
214
+ args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
215
+
216
+ # lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する / prepare the lr scheduler. TODO: the handling of gradient_accumulation_steps may be off, check later
217
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
218
+
219
+ # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする / experimental feature: train fully in fp16 including gradients, casting the whole model to fp16
220
+ if args.full_fp16:
221
+ assert (
222
+ args.mixed_precision == "fp16"
223
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
224
+ accelerator.print("enable full fp16 training.")
225
+ unet.to(weight_dtype)
226
+ text_encoder.to(weight_dtype)
227
+
228
+ # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
229
+ if args.deepspeed:
230
+ if args.train_text_encoder:
231
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
232
+ else:
233
+ ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
234
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
235
+ ds_model, optimizer, train_dataloader, lr_scheduler
236
+ )
237
+ training_models = [ds_model]
238
+
239
+ else:
240
+ if train_text_encoder:
241
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
242
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
243
+ )
244
+ training_models = [unet, text_encoder]
245
+ else:
246
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
247
+ training_models = [unet]
248
+
249
+ if not train_text_encoder:
250
+ text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
251
+
252
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする / experimental feature: patch PyTorch to enable grad scaling in fp16 for full-fp16 training
253
+ if args.full_fp16:
254
+ train_util.patch_accelerator_for_fp16_training(accelerator)
255
+
256
+ # resumeする / resume if specified
257
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
258
+
259
+ # epoch数を計算する / calculate the number of epochs
260
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
261
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
262
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
263
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
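+ # e.g. save_n_epoch_ratio=3 with 10 total epochs -> save_every_n_epochs = floor(10 / 3) = 3,
+ # i.e. roughly 3 intermediate saves; the "or 1" keeps the interval from becoming 0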
264
+
265
+ # 学習する / start training
266
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
267
+ accelerator.print("running training / 学習開始")
268
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
269
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
270
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
271
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
272
+ accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
273
+ accelerator.print(
274
+ f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
275
+ )
276
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
277
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
278
+
279
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
280
+ global_step = 0
281
+
282
+ noise_scheduler = DDPMScheduler(
283
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
284
+ )
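+ # NOTE: beta_start/beta_end, scaled_linear and 1000 train timesteps are the standard Stable Diffusion v1.x
+ # noise-schedule constants, so the training noise matches the base model's schedule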
285
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
286
+ if args.zero_terminal_snr:
287
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
288
+
289
+ if accelerator.is_main_process:
290
+ init_kwargs = {}
291
+ if args.wandb_run_name:
292
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
293
+ if args.log_tracker_config is not None:
294
+ init_kwargs = toml.load(args.log_tracker_config)
295
+ accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
296
+
297
+ # For --sample_at_first
298
+ train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
299
+
300
+ loss_recorder = train_util.LossRecorder()
301
+ for epoch in range(num_train_epochs):
302
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
303
+ current_epoch.value = epoch + 1
304
+
305
+ # 指定したステップ数までText Encoderを学習する:epoch最初の状態 / train the Text Encoder only up to the specified step: state at the start of each epoch
306
+ unet.train()
307
+ # train==True is required to enable gradient_checkpointing
308
+ if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
309
+ text_encoder.train()
310
+
311
+ for step, batch in enumerate(train_dataloader):
312
+ current_step.value = global_step
313
+ # 指定したステップ数でText Encoderの学習を止める / stop Text Encoder training at the specified step
314
+ if global_step == args.stop_text_encoder_training:
315
+ accelerator.print(f"stop text encoder training at step {global_step}")
316
+ if not args.gradient_checkpointing:
317
+ text_encoder.train(False)
318
+ text_encoder.requires_grad_(False)
319
+ if len(training_models) == 2:
320
+ training_models = training_models[:1] # remove text_encoder from training_models (keep it a list so accumulate(*training_models) still unpacks)
321
+
322
+ with accelerator.accumulate(*training_models):
323
+ with torch.no_grad():
324
+ # latentに変換 / convert to latents
325
+ if cache_latents:
326
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
327
+ else:
328
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
329
+ latents = latents * 0.18215
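+ # NOTE: 0.18215 is the Stable Diffusion VAE latent scaling factor (vae.config.scaling_factor)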
330
+ b_size = latents.shape[0]
331
+
332
+ # Get the text embedding for conditioning
333
+ with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
334
+ if args.weighted_captions:
335
+ encoder_hidden_states = get_weighted_text_embeddings(
336
+ tokenizer,
337
+ text_encoder,
338
+ batch["captions"],
339
+ accelerator.device,
340
+ args.max_token_length // 75 if args.max_token_length else 1,
341
+ clip_skip=args.clip_skip,
342
+ )
343
+ else:
344
+ input_ids = batch["input_ids"].to(accelerator.device)
345
+ encoder_hidden_states = train_util.get_hidden_states(
346
+ args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
347
+ )
348
+
349
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
350
+ # with noise offset and/or multires noise if specified
351
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
352
+
353
+ # Predict the noise residual
354
+ with accelerator.autocast():
355
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
356
+
357
+ if args.v_parameterization:
358
+ # v-parameterization training
359
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
360
+ else:
361
+ target = noise
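+ # NOTE: with v-parameterization the target is v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents,
+ # which is what noise_scheduler.get_velocity returns; otherwise plain epsilon (noise) prediction is used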
362
+
363
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
364
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
365
+ loss = apply_masked_loss(loss, batch)
366
+ loss = loss.mean([1, 2, 3])
367
+
368
+ loss_weights = batch["loss_weights"] # 各sampleごとのweight / weight for each sample
369
+ loss = loss * loss_weights
370
+
371
+ if args.min_snr_gamma:
372
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
373
+ if args.scale_v_pred_loss_like_noise_pred:
374
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
375
+ if args.debiased_estimation_loss:
376
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
377
+
378
+ loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし / already the mean, no need to divide by batch_size
379
+
380
+ accelerator.backward(loss)
381
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
382
+ if train_text_encoder:
383
+ params_to_clip = itertools.chain(unet.parameters(), text_encoder.parameters())
384
+ else:
385
+ params_to_clip = unet.parameters()
386
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
387
+
388
+ optimizer.step()
389
+ lr_scheduler.step()
390
+ optimizer.zero_grad(set_to_none=True)
391
+
392
+ # Checks if the accelerator has performed an optimization step behind the scenes
393
+ if accelerator.sync_gradients:
394
+ progress_bar.update(1)
395
+ global_step += 1
396
+
397
+ train_util.sample_images(
398
+ accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
399
+ )
400
+
401
+ # 指定ステップごとにモデルを保存 / save the model every specified number of steps
402
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
403
+ accelerator.wait_for_everyone()
404
+ if accelerator.is_main_process:
405
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
406
+ train_util.save_sd_model_on_epoch_end_or_stepwise(
407
+ args,
408
+ False,
409
+ accelerator,
410
+ src_path,
411
+ save_stable_diffusion_format,
412
+ use_safetensors,
413
+ save_dtype,
414
+ epoch,
415
+ num_train_epochs,
416
+ global_step,
417
+ accelerator.unwrap_model(text_encoder),
418
+ accelerator.unwrap_model(unet),
419
+ vae,
420
+ )
421
+
422
+ current_loss = loss.detach().item()
423
+ if args.logging_dir is not None:
424
+ logs = {"loss": current_loss}
425
+ train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
426
+ accelerator.log(logs, step=global_step)
427
+
428
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
429
+ avr_loss: float = loss_recorder.moving_average
430
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
431
+ progress_bar.set_postfix(**logs)
432
+
433
+ if global_step >= args.max_train_steps:
434
+ break
435
+
436
+ if args.logging_dir is not None:
437
+ logs = {"loss/epoch": loss_recorder.moving_average}
438
+ accelerator.log(logs, step=epoch + 1)
439
+
440
+ accelerator.wait_for_everyone()
441
+
442
+ if args.save_every_n_epochs is not None:
443
+ if accelerator.is_main_process:
444
+ # the check for whether to save this epoch is done inside the util function
445
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
446
+ train_util.save_sd_model_on_epoch_end_or_stepwise(
447
+ args,
448
+ True,
449
+ accelerator,
450
+ src_path,
451
+ save_stable_diffusion_format,
452
+ use_safetensors,
453
+ save_dtype,
454
+ epoch,
455
+ num_train_epochs,
456
+ global_step,
457
+ accelerator.unwrap_model(text_encoder),
458
+ accelerator.unwrap_model(unet),
459
+ vae,
460
+ )
461
+
462
+ train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
463
+
464
+ is_main_process = accelerator.is_main_process
465
+ if is_main_process:
466
+ unet = accelerator.unwrap_model(unet)
467
+ text_encoder = accelerator.unwrap_model(text_encoder)
468
+
469
+ accelerator.end_training()
470
+
471
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
472
+ train_util.save_state_on_train_end(args, accelerator)
473
+
474
+ del accelerator # この後メモリを使うのでこれは消す / delete this because memory is needed afterwards
475
+
476
+ if is_main_process:
477
+ src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
478
+ train_util.save_sd_model_on_train_end(
479
+ args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
480
+ )
481
+ logger.info("model saved.")
482
+
483
+
484
+ def setup_parser() -> argparse.ArgumentParser:
485
+ parser = argparse.ArgumentParser()
486
+
487
+ add_logging_arguments(parser)
488
+ train_util.add_sd_models_arguments(parser)
489
+ train_util.add_dataset_arguments(parser, True, False, True)
490
+ train_util.add_training_arguments(parser, True)
491
+ train_util.add_masked_loss_arguments(parser)
492
+ deepspeed_utils.add_deepspeed_arguments(parser)
493
+ train_util.add_sd_saving_arguments(parser)
494
+ train_util.add_optimizer_arguments(parser)
495
+ config_util.add_config_arguments(parser)
496
+ custom_train_functions.add_custom_train_arguments(parser)
497
+
498
+ parser.add_argument(
499
+ "--learning_rate_te",
500
+ type=float,
501
+ default=None,
502
+ help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
503
+ )
504
+ parser.add_argument(
505
+ "--no_token_padding",
506
+ action="store_true",
507
+ help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)",
508
+ )
509
+ parser.add_argument(
510
+ "--stop_text_encoder_training",
511
+ type=int,
512
+ default=None,
513
+ help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
514
+ )
515
+ parser.add_argument(
516
+ "--no_half_vae",
517
+ action="store_true",
518
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
519
+ )
520
+
521
+ return parser
522
+
523
+
524
+ if __name__ == "__main__":
525
+ parser = setup_parser()
526
+
527
+ args = parser.parse_args()
528
+ train_util.verify_command_line_training_args(args)
529
+ args = train_util.read_config_from_file(args, parser)
530
+
531
+ train(args)
train_network.py ADDED
@@ -0,0 +1,1250 @@
1
+ import importlib
2
+ import argparse
3
+ import math
4
+ import os
5
+ import sys
6
+ import random
7
+ import time
8
+ import json
9
+ from multiprocessing import Value
10
+ import toml
11
+
12
+ from tqdm import tqdm
13
+
14
+ import torch
15
+ from library.device_utils import init_ipex, clean_memory_on_device
16
+
17
+ init_ipex()
18
+
19
+ from accelerate.utils import set_seed
20
+ from diffusers import DDPMScheduler
21
+ from library import deepspeed_utils, model_util
22
+
23
+ import library.train_util as train_util
24
+ from library.train_util import DreamBoothDataset
25
+ import library.config_util as config_util
26
+ from library.config_util import (
27
+ ConfigSanitizer,
28
+ BlueprintGenerator,
29
+ )
30
+ import library.huggingface_util as huggingface_util
31
+ import library.custom_train_functions as custom_train_functions
32
+ from library.custom_train_functions import (
33
+ apply_snr_weight,
34
+ get_weighted_text_embeddings,
35
+ prepare_scheduler_for_custom_training,
36
+ scale_v_prediction_loss_like_noise_prediction,
37
+ add_v_prediction_like_loss,
38
+ apply_debiased_estimation,
39
+ apply_masked_loss,
40
+ )
41
+ from library.utils import setup_logging, add_logging_arguments
42
+
43
+ setup_logging()
44
+ import logging
45
+
46
+ logger = logging.getLogger(__name__)
47
+
48
+
49
+ class NetworkTrainer:
50
+ def __init__(self):
51
+ self.vae_scale_factor = 0.18215
52
+ self.is_sdxl = False
53
+
54
+ # TODO 他のスクリプトと共通化する / TODO: share this with the other scripts
55
+ def generate_step_logs(
56
+ self,
57
+ args: argparse.Namespace,
58
+ current_loss,
59
+ avr_loss,
60
+ lr_scheduler,
61
+ lr_descriptions,
62
+ keys_scaled=None,
63
+ mean_norm=None,
64
+ maximum_norm=None,
65
+ ):
66
+ logs = {"loss/current": current_loss, "loss/average": avr_loss}
67
+
68
+ if keys_scaled is not None:
69
+ logs["max_norm/keys_scaled"] = keys_scaled
70
+ logs["max_norm/average_key_norm"] = mean_norm
71
+ logs["max_norm/max_key_norm"] = maximum_norm
72
+
73
+ lrs = lr_scheduler.get_last_lr()
74
+ for i, lr in enumerate(lrs):
75
+ if lr_descriptions is not None:
76
+ lr_desc = lr_descriptions[i]
77
+ else:
78
+ idx = i - (0 if args.network_train_unet_only else -1)
79
+ if idx == -1:
80
+ lr_desc = "textencoder"
81
+ else:
82
+ if len(lrs) > 2:
83
+ lr_desc = f"group{idx}"
84
+ else:
85
+ lr_desc = "unet"
86
+
87
+ logs[f"lr/{lr_desc}"] = lr
88
+
89
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
90
+ # tracking d*lr value
91
+ logs[f"lr/d*lr/{lr_desc}"] = (
92
+ lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
93
+ )
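+ # NOTE: D-Adaptation and Prodigy keep the configured lr near 1.0 and learn a scale factor "d",
+ # so d * lr is the effective learning rate actually being applied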
94
+
95
+ return logs
96
+
97
+ def assert_extra_args(self, args, train_dataset_group):
98
+ train_dataset_group.verify_bucket_reso_steps(64)
99
+
100
+ def load_target_model(self, args, weight_dtype, accelerator):
101
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
102
+ return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
103
+
104
+ def load_tokenizer(self, args):
105
+ tokenizer = train_util.load_tokenizer(args)
106
+ return tokenizer
107
+
108
+ def is_text_encoder_outputs_cached(self, args):
109
+ return False
110
+
111
+ def is_train_text_encoder(self, args):
112
+ return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
113
+
114
+ def cache_text_encoder_outputs_if_needed(
115
+ self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
116
+ ):
117
+ for t_enc in text_encoders:
118
+ t_enc.to(accelerator.device, dtype=weight_dtype)
119
+
120
+ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
121
+ input_ids = batch["input_ids"].to(accelerator.device)
122
+ encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], weight_dtype)
123
+ return encoder_hidden_states
124
+
125
+ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
126
+ noise_pred = unet(noisy_latents, timesteps, text_conds).sample
127
+ return noise_pred
128
+
129
+ def all_reduce_network(self, accelerator, network):
130
+ for param in network.parameters():
131
+ if param.grad is not None:
132
+ param.grad = accelerator.reduce(param.grad, reduction="mean")
133
+
134
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
135
+ train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
136
+
137
+ def train(self, args):
138
+ session_id = random.randint(0, 2**32)
139
+ training_started_at = time.time()
140
+ train_util.verify_training_args(args)
141
+ train_util.prepare_dataset_args(args, True)
142
+ deepspeed_utils.prepare_deepspeed_args(args)
143
+ setup_logging(args, reset=True)
144
+
145
+ cache_latents = args.cache_latents
146
+ use_dreambooth_method = args.in_json is None
147
+ use_user_config = args.dataset_config is not None
148
+
149
+ if args.seed is None:
150
+ args.seed = random.randint(0, 2**32)
151
+ set_seed(args.seed)
152
+
153
+ # tokenizerは単体またはリスト、tokenizersは必ずリスト:既存のコードとの互換性のため / tokenizer may be a single tokenizer or a list, tokenizers is always a list: for compatibility with existing code
154
+ tokenizer = self.load_tokenizer(args)
155
+ tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
156
+
157
+ # データセットを準備する / prepare the dataset
158
+ if args.dataset_class is None:
159
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
160
+ if use_user_config:
161
+ logger.info(f"Loading dataset config from {args.dataset_config}")
162
+ user_config = config_util.load_user_config(args.dataset_config)
163
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
164
+ if any(getattr(args, attr) is not None for attr in ignored):
165
+ logger.warning(
166
+ "ignoring the following options because a config file is provided: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
167
+ ", ".join(ignored)
168
+ )
169
+ )
170
+ else:
171
+ if use_dreambooth_method:
172
+ logger.info("Using DreamBooth method.")
173
+ user_config = {
174
+ "datasets": [
175
+ {
176
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
177
+ args.train_data_dir, args.reg_data_dir
178
+ )
179
+ }
180
+ ]
181
+ }
182
+ else:
183
+ logger.info("Training with captions.")
184
+ user_config = {
185
+ "datasets": [
186
+ {
187
+ "subsets": [
188
+ {
189
+ "image_dir": args.train_data_dir,
190
+ "metadata_file": args.in_json,
191
+ }
192
+ ]
193
+ }
194
+ ]
195
+ }
196
+
197
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
198
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
199
+ else:
200
+ # use arbitrary dataset class
201
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
202
+
203
+ current_epoch = Value("i", 0)
204
+ current_step = Value("i", 0)
205
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
206
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
207
+
208
+ if args.debug_dataset:
209
+ train_util.debug_dataset(train_dataset_group)
210
+ return
211
+ if len(train_dataset_group) == 0:
212
+ logger.error(
213
+ "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
214
+ )
215
+ return
216
+
217
+ if cache_latents:
218
+ assert (
219
+ train_dataset_group.is_latent_cacheable()
220
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
221
+
222
+ self.assert_extra_args(args, train_dataset_group)
223
+
224
+ # acceleratorを準備する / prepare the accelerator
225
+ logger.info("preparing accelerator")
226
+ accelerator = train_util.prepare_accelerator(args)
227
+ is_main_process = accelerator.is_main_process
228
+
229
+ # mixed precisionに対応した型を用意しておき適宜castする / prepare dtypes for mixed precision and cast as needed
230
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
231
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
232
+
233
+ # モデルを読み込む / load the models
234
+ model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
235
+
236
+ # text_encoder is List[CLIPTextModel] or CLIPTextModel
237
+ text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
238
+
239
+ # モデルに xformers とか memory efficient attention を組み込む / integrate xformers or memory-efficient attention into the model
240
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
241
+ if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える / with xformers built for PyTorch 2.0.0 or later, the following can be used
242
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
243
+
244
+ # 差分追加学習のためにモデルを読み込む / load the network module for additional (difference) training
245
+ sys.path.append(os.path.dirname(__file__))
246
+ accelerator.print("import network module:", args.network_module)
247
+ network_module = importlib.import_module(args.network_module)
248
+
249
+ if args.base_weights is not None:
250
+ # base_weights が指定されている場合は、指定された重みを読み込みマージする / if base_weights is specified, load and merge the given weights
251
+ for i, weight_path in enumerate(args.base_weights):
252
+ if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
253
+ multiplier = 1.0
254
+ else:
255
+ multiplier = args.base_weights_multiplier[i]
256
+
257
+ accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}")
258
+
259
+ module, weights_sd = network_module.create_network_from_weights(
260
+ multiplier, weight_path, vae, text_encoder, unet, for_inference=True
261
+ )
262
+ module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
263
+
264
+ accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
265
+
266
+ # 学習を準備する / prepare for training
267
+ if cache_latents:
268
+ vae.to(accelerator.device, dtype=vae_dtype)
269
+ vae.requires_grad_(False)
270
+ vae.eval()
271
+ with torch.no_grad():
272
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
273
+ vae.to("cpu")
274
+ clean_memory_on_device(accelerator.device)
275
+
276
+ accelerator.wait_for_everyone()
277
+
278
+ # 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
279
+ # cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
280
+ self.cache_text_encoder_outputs_if_needed(
281
+ args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
282
+ )
283
+
284
+ # prepare network
285
+ net_kwargs = {}
286
+ if args.network_args is not None:
287
+ for net_arg in args.network_args:
288
+ key, value = net_arg.split("=")
289
+ net_kwargs[key] = value
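+ # NOTE: each --network_args entry is a "key=value" string passed straight to the network module,
+ # e.g. --network_args "conv_dim=4" "conv_alpha=1" (valid keys depend on the module; values remain strings here)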
290
+
291
+ # if a new network is added in future, add if ~ then blocks for each network (;'∀')
292
+ if args.dim_from_weights:
293
+ network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
294
+ else:
295
+ if "dropout" not in net_kwargs:
296
+ # workaround for LyCORIS (;^ω^)
297
+ net_kwargs["dropout"] = args.network_dropout
298
+
299
+ network = network_module.create_network(
300
+ 1.0,
301
+ args.network_dim,
302
+ args.network_alpha,
303
+ vae,
304
+ text_encoder,
305
+ unet,
306
+ neuron_dropout=args.network_dropout,
307
+ **net_kwargs,
308
+ )
309
+ if network is None:
310
+ return
311
+ network_has_multiplier = hasattr(network, "set_multiplier")
312
+
313
+ if hasattr(network, "prepare_network"):
314
+ network.prepare_network(args)
315
+ if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
316
+ logger.warning(
317
+ "warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
318
+ )
319
+ args.scale_weight_norms = False
320
+
321
+ train_unet = not args.network_train_text_encoder_only
322
+ train_text_encoder = self.is_train_text_encoder(args)
323
+ network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
324
+
325
+ if args.network_weights is not None:
326
+ # FIXME consider alpha of weights
327
+ info = network.load_weights(args.network_weights)
328
+ accelerator.print(f"load network weights from {args.network_weights}: {info}")
329
+
330
+ if args.gradient_checkpointing:
331
+ unet.enable_gradient_checkpointing()
332
+ for t_enc in text_encoders:
333
+ t_enc.gradient_checkpointing_enable()
334
+ del t_enc
335
+ network.enable_gradient_checkpointing() # may have no effect
336
+
337
+ # 学習に必要なクラスを準備する / prepare the classes required for training
338
+ accelerator.print("prepare optimizer, data loader etc.")
339
+
340
+ # 後方互換性を確保するよ / keep backward compatibility
341
+ try:
342
+ results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
343
+ if type(results) is tuple:
344
+ trainable_params = results[0]
345
+ lr_descriptions = results[1]
346
+ else:
347
+ trainable_params = results
348
+ lr_descriptions = None
349
+ except TypeError as e:
350
+ # logger.warning(f"{e}")
351
+ # accelerator.print(
352
+ # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
353
+ # )
354
+ trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
355
+ lr_descriptions = None
356
+
357
+ # if len(trainable_params) == 0:
358
+ # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
359
+ # for params in trainable_params:
360
+ # for k, v in params.items():
361
+ # if type(v) == float:
362
+ # pass
363
+ # else:
364
+ # v = len(v)
365
+ # accelerator.print(f"trainable_params: {k} = {v}")
366
+
367
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
368
+
369
+ # dataloaderを準備する / prepare the DataLoader
370
+ # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 / note that persistent_workers cannot be used when the number of DataLoader workers is 0
371
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
372
+
373
+ train_dataloader = torch.utils.data.DataLoader(
374
+ train_dataset_group,
375
+ batch_size=1,
376
+ shuffle=True,
377
+ collate_fn=collator,
378
+ num_workers=n_workers,
379
+ persistent_workers=args.persistent_data_loader_workers,
380
+ )
381
+
382
+ # 学習ステップ数を計算する / calculate the number of training steps
383
+ if args.max_train_epochs is not None:
384
+ args.max_train_steps = args.max_train_epochs * math.ceil(
385
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
386
+ )
387
+ accelerator.print(
388
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
389
+ )
390
+
391
+ # データセット側にも学習ステップを送信 / also pass the training step count to the dataset
392
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
393
+
394
+ # lr schedulerを用意する / prepare the lr scheduler
395
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
396
+
397
+ # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする / experimental feature: train fully in fp16/bf16 including gradients, casting the whole model to fp16/bf16
398
+ if args.full_fp16:
399
+ assert (
400
+ args.mixed_precision == "fp16"
401
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
402
+ accelerator.print("enable full fp16 training.")
403
+ network.to(weight_dtype)
404
+ elif args.full_bf16:
405
+ assert (
406
+ args.mixed_precision == "bf16"
407
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
408
+ accelerator.print("enable full bf16 training.")
409
+ network.to(weight_dtype)
410
+
411
+ unet_weight_dtype = te_weight_dtype = weight_dtype
412
+ # Experimental Feature: Put base model into fp8 to save vram
413
+ if args.fp8_base:
414
+ assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
415
+ assert (
416
+ args.mixed_precision != "no"
417
+ ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
418
+ accelerator.print("enable fp8 training.")
419
+ unet_weight_dtype = torch.float8_e4m3fn
420
+ te_weight_dtype = torch.float8_e4m3fn
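+ # NOTE: float8_e4m3fn appears to serve only as a storage dtype for the frozen base weights to save VRAM;
+ # compute still runs under autocast in fp16/bf16, and the embeddings are kept in a higher-precision dtype below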
421
+
422
+ unet.requires_grad_(False)
423
+ unet.to(dtype=unet_weight_dtype)
424
+ for t_enc in text_encoders:
425
+ t_enc.requires_grad_(False)
426
+
427
+ # in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
428
+ if t_enc.device.type != "cpu":
429
+ t_enc.to(dtype=te_weight_dtype)
430
+ # nn.Embedding not support FP8
431
+ t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
432
+
433
+ # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
434
+ if args.deepspeed:
435
+ ds_model = deepspeed_utils.prepare_deepspeed_model(
436
+ args,
437
+ unet=unet if train_unet else None,
438
+ text_encoder1=text_encoders[0] if train_text_encoder else None,
439
+ text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
440
+ network=network,
441
+ )
442
+ ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
443
+ ds_model, optimizer, train_dataloader, lr_scheduler
444
+ )
445
+ training_model = ds_model
446
+ else:
447
+ if train_unet:
448
+ unet = accelerator.prepare(unet)
449
+ else:
450
+ unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
451
+ if train_text_encoder:
452
+ if len(text_encoders) > 1:
453
+ text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
454
+ else:
455
+ text_encoder = accelerator.prepare(text_encoder)
456
+ text_encoders = [text_encoder]
457
+ else:
458
+ pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
459
+
460
+ network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
461
+ network, optimizer, train_dataloader, lr_scheduler
462
+ )
463
+ training_model = network
464
+
465
+ if args.gradient_checkpointing:
466
+ # according to TI example in Diffusers, train is required
467
+ unet.train()
468
+ for t_enc in text_encoders:
469
+ t_enc.train()
470
+
471
+ # set requires_grad = True on the top-level parameters so that gradient checkpointing works
472
+ if train_text_encoder:
473
+ t_enc.text_model.embeddings.requires_grad_(True)
474
+
475
+ else:
476
+ unet.eval()
477
+ for t_enc in text_encoders:
478
+ t_enc.eval()
479
+
480
+ del t_enc
481
+
482
+ accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
483
+
484
+ if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する / the VAE is used when latents are not cached, so prepare it
485
+ vae.requires_grad_(False)
486
+ vae.eval()
487
+ vae.to(accelerator.device, dtype=vae_dtype)
488
+
489
+ # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする / experimental feature: patch PyTorch to enable grad scaling in fp16 for full-fp16 training
490
+ if args.full_fp16:
491
+ train_util.patch_accelerator_for_fp16_training(accelerator)
492
+
493
+ # before resuming make hook for saving/loading to save/load the network weights only
494
+ def save_model_hook(models, weights, output_dir):
495
+ # pop weights of other models than network to save only network weights
496
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
497
+ if accelerator.is_main_process or args.deepspeed:
498
+ remove_indices = []
499
+ for i, model in enumerate(models):
500
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
501
+ remove_indices.append(i)
502
+ for i in reversed(remove_indices):
503
+ if len(weights) > i:
504
+ weights.pop(i)
505
+ # print(f"save model hook: {len(weights)} weights will be saved")
506
+
507
+ # save current epoch and step
508
+ train_state_file = os.path.join(output_dir, "train_state.json")
509
+ # +1 is needed because the state is saved before current_step is set from global_step
510
+ logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
511
+ with open(train_state_file, "w", encoding="utf-8") as f:
512
+ json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
513
+
514
+ steps_from_state = None
515
+
516
+ def load_model_hook(models, input_dir):
517
+ # remove models except network
518
+ remove_indices = []
519
+ for i, model in enumerate(models):
520
+ if not isinstance(model, type(accelerator.unwrap_model(network))):
521
+ remove_indices.append(i)
522
+ for i in reversed(remove_indices):
523
+ models.pop(i)
524
+ # print(f"load model hook: {len(models)} models will be loaded")
525
+
526
+ # load current epoch and step
527
+ nonlocal steps_from_state
528
+ train_state_file = os.path.join(input_dir, "train_state.json")
529
+ if os.path.exists(train_state_file):
530
+ with open(train_state_file, "r", encoding="utf-8") as f:
531
+ data = json.load(f)
532
+ steps_from_state = data["current_step"]
533
+ logger.info(f"load train state from {train_state_file}: {data}")
534
+
535
+ accelerator.register_save_state_pre_hook(save_model_hook)
536
+ accelerator.register_load_state_pre_hook(load_model_hook)
537
+
538
+ # resumeする / resume if specified
539
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
540
+
541
+ # epoch数を計算する / calculate the number of epochs
542
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
543
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
544
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
545
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
546
+
547
+ # 学習する / start training
548
+ # TODO: find a way to handle total batch size when there are multiple datasets
549
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
550
+
551
+ accelerator.print("running training / 学習開始")
552
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
553
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
554
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
555
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
556
+ accelerator.print(
557
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
558
+ )
559
+ # accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
560
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
561
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
562
+
563
+ # TODO refactor metadata creation and move to util
564
+ metadata = {
565
+ "ss_session_id": session_id, # random integer indicating which group of epochs the model came from
566
+ "ss_training_started_at": training_started_at, # unix timestamp
567
+ "ss_output_name": args.output_name,
568
+ "ss_learning_rate": args.learning_rate,
569
+ "ss_text_encoder_lr": args.text_encoder_lr,
570
+ "ss_unet_lr": args.unet_lr,
571
+ "ss_num_train_images": train_dataset_group.num_train_images,
572
+ "ss_num_reg_images": train_dataset_group.num_reg_images,
573
+ "ss_num_batches_per_epoch": len(train_dataloader),
574
+ "ss_num_epochs": num_train_epochs,
575
+ "ss_gradient_checkpointing": args.gradient_checkpointing,
576
+ "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
577
+ "ss_max_train_steps": args.max_train_steps,
578
+ "ss_lr_warmup_steps": args.lr_warmup_steps,
579
+ "ss_lr_scheduler": args.lr_scheduler,
580
+ "ss_network_module": args.network_module,
581
+ "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
582
+ "ss_network_alpha": args.network_alpha, # some networks may not have alpha
583
+ "ss_network_dropout": args.network_dropout, # some networks may not have dropout
584
+ "ss_mixed_precision": args.mixed_precision,
585
+ "ss_full_fp16": bool(args.full_fp16),
586
+ "ss_v2": bool(args.v2),
587
+ "ss_base_model_version": model_version,
588
+ "ss_clip_skip": args.clip_skip,
589
+ "ss_max_token_length": args.max_token_length,
590
+ "ss_cache_latents": bool(args.cache_latents),
591
+ "ss_seed": args.seed,
592
+ "ss_lowram": args.lowram,
593
+ "ss_noise_offset": args.noise_offset,
594
+ "ss_multires_noise_iterations": args.multires_noise_iterations,
595
+ "ss_multires_noise_discount": args.multires_noise_discount,
596
+ "ss_adaptive_noise_scale": args.adaptive_noise_scale,
597
+ "ss_zero_terminal_snr": args.zero_terminal_snr,
598
+ "ss_training_comment": args.training_comment, # will not be updated after training
599
+ "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
600
+ "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
601
+ "ss_max_grad_norm": args.max_grad_norm,
602
+ "ss_caption_dropout_rate": args.caption_dropout_rate,
603
+ "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
604
+ "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
605
+ "ss_face_crop_aug_range": args.face_crop_aug_range,
606
+ "ss_prior_loss_weight": args.prior_loss_weight,
607
+ "ss_min_snr_gamma": args.min_snr_gamma,
608
+ "ss_scale_weight_norms": args.scale_weight_norms,
609
+ "ss_ip_noise_gamma": args.ip_noise_gamma,
610
+ "ss_debiased_estimation": bool(args.debiased_estimation_loss),
611
+ "ss_noise_offset_random_strength": args.noise_offset_random_strength,
612
+ "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
613
+ "ss_loss_type": args.loss_type,
614
+ "ss_huber_schedule": args.huber_schedule,
615
+ "ss_huber_c": args.huber_c,
616
+ }
617
+
618
+ if use_user_config:
619
+ # save metadata of multiple datasets
620
+ # NOTE: pack "ss_datasets" value as json one time
621
+ # or should also pack nested collections as json?
622
+ datasets_metadata = []
623
+ tag_frequency = {} # merge tag frequency for metadata editor
624
+ dataset_dirs_info = {} # merge subset dirs for metadata editor
625
+
626
+ for dataset in train_dataset_group.datasets:
627
+ is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
628
+ dataset_metadata = {
629
+ "is_dreambooth": is_dreambooth_dataset,
630
+ "batch_size_per_device": dataset.batch_size,
631
+ "num_train_images": dataset.num_train_images, # includes repeating
632
+ "num_reg_images": dataset.num_reg_images,
633
+ "resolution": (dataset.width, dataset.height),
634
+ "enable_bucket": bool(dataset.enable_bucket),
635
+ "min_bucket_reso": dataset.min_bucket_reso,
636
+ "max_bucket_reso": dataset.max_bucket_reso,
637
+ "tag_frequency": dataset.tag_frequency,
638
+ "bucket_info": dataset.bucket_info,
639
+ }
640
+
641
+ subsets_metadata = []
642
+ for subset in dataset.subsets:
643
+ subset_metadata = {
644
+ "img_count": subset.img_count,
645
+ "num_repeats": subset.num_repeats,
646
+ "color_aug": bool(subset.color_aug),
647
+ "flip_aug": bool(subset.flip_aug),
648
+ "random_crop": bool(subset.random_crop),
649
+ "shuffle_caption": bool(subset.shuffle_caption),
650
+ "keep_tokens": subset.keep_tokens,
651
+ "keep_tokens_separator": subset.keep_tokens_separator,
652
+ "secondary_separator": subset.secondary_separator,
653
+ "enable_wildcard": bool(subset.enable_wildcard),
654
+ "caption_prefix": subset.caption_prefix,
655
+ "caption_suffix": subset.caption_suffix,
656
+ }
657
+
658
+ image_dir_or_metadata_file = None
659
+ if subset.image_dir:
660
+ image_dir = os.path.basename(subset.image_dir)
661
+ subset_metadata["image_dir"] = image_dir
662
+ image_dir_or_metadata_file = image_dir
663
+
664
+ if is_dreambooth_dataset:
665
+ subset_metadata["class_tokens"] = subset.class_tokens
666
+ subset_metadata["is_reg"] = subset.is_reg
667
+ if subset.is_reg:
668
+ image_dir_or_metadata_file = None # not merging reg dataset
669
+ else:
670
+ metadata_file = os.path.basename(subset.metadata_file)
671
+ subset_metadata["metadata_file"] = metadata_file
672
+ image_dir_or_metadata_file = metadata_file # may overwrite
673
+
674
+ subsets_metadata.append(subset_metadata)
675
+
676
+ # merge dataset dir: not reg subset only
677
+ # TODO update additional-network extension to show detailed dataset config from metadata
678
+ if image_dir_or_metadata_file is not None:
679
+ # datasets may have a certain dir multiple times
680
+ v = image_dir_or_metadata_file
681
+ i = 2
682
+ while v in dataset_dirs_info:
683
+ v = image_dir_or_metadata_file + f" ({i})"
684
+ i += 1
685
+ image_dir_or_metadata_file = v
686
+
687
+ dataset_dirs_info[image_dir_or_metadata_file] = {
688
+ "n_repeats": subset.num_repeats,
689
+ "img_count": subset.img_count,
690
+ }
691
+
692
+ dataset_metadata["subsets"] = subsets_metadata
693
+ datasets_metadata.append(dataset_metadata)
694
+
695
+ # merge tag frequency:
696
+ for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
697
+ # あるディレクトリが複数のdatasetで使用されている場合、一度だけ数える / if a directory is used by multiple datasets, count it only once
698
+ # もともと繰り返し回数を指定しているので、キャプション内でのタグの出現回数と、それが学習で何度使われるかは一致しない / repeat counts are specified separately, so tag frequency in captions does not match how often a tag is actually seen during training
699
+ # なので、ここで複数datasetの回数を合算してもあまり意味はない / therefore summing the counts across datasets here would not mean much
700
+ if ds_dir_name in tag_frequency:
701
+ continue
702
+ tag_frequency[ds_dir_name] = ds_freq_for_dir
703
+
704
+ metadata["ss_datasets"] = json.dumps(datasets_metadata)
705
+ metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
706
+ metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
707
+ else:
708
+ # conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
709
+ assert (
710
+ len(train_dataset_group.datasets) == 1
711
+ ), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
712
+
713
+ dataset = train_dataset_group.datasets[0]
714
+
715
+ dataset_dirs_info = {}
716
+ reg_dataset_dirs_info = {}
717
+ if use_dreambooth_method:
718
+ for subset in dataset.subsets:
719
+ info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
720
+ info[os.path.basename(subset.image_dir)] = {"n_repeats": subset.num_repeats, "img_count": subset.img_count}
721
+ else:
722
+ for subset in dataset.subsets:
723
+ dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
724
+ "n_repeats": subset.num_repeats,
725
+ "img_count": subset.img_count,
726
+ }
727
+
728
+ metadata.update(
729
+ {
730
+ "ss_batch_size_per_device": args.train_batch_size,
731
+ "ss_total_batch_size": total_batch_size,
732
+ "ss_resolution": args.resolution,
733
+ "ss_color_aug": bool(args.color_aug),
734
+ "ss_flip_aug": bool(args.flip_aug),
735
+ "ss_random_crop": bool(args.random_crop),
736
+ "ss_shuffle_caption": bool(args.shuffle_caption),
737
+ "ss_enable_bucket": bool(dataset.enable_bucket),
738
+ "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
739
+ "ss_min_bucket_reso": dataset.min_bucket_reso,
740
+ "ss_max_bucket_reso": dataset.max_bucket_reso,
741
+ "ss_keep_tokens": args.keep_tokens,
742
+ "ss_dataset_dirs": json.dumps(dataset_dirs_info),
743
+ "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
744
+ "ss_tag_frequency": json.dumps(dataset.tag_frequency),
745
+ "ss_bucket_info": json.dumps(dataset.bucket_info),
746
+ }
747
+ )
748
+
749
+ # add extra args
750
+ if args.network_args:
751
+ metadata["ss_network_args"] = json.dumps(net_kwargs)
752
+
753
+ # model name and hash
754
+ if args.pretrained_model_name_or_path is not None:
755
+ sd_model_name = args.pretrained_model_name_or_path
756
+ if os.path.exists(sd_model_name):
757
+ metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
758
+ metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
759
+ sd_model_name = os.path.basename(sd_model_name)
760
+ metadata["ss_sd_model_name"] = sd_model_name
761
+
762
+ if args.vae is not None:
763
+ vae_name = args.vae
764
+ if os.path.exists(vae_name):
765
+ metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
766
+ metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
767
+ vae_name = os.path.basename(vae_name)
768
+ metadata["ss_vae_name"] = vae_name
769
+
770
+ metadata = {k: str(v) for k, v in metadata.items()}
771
+
772
+ # make minimum metadata for filtering
773
+ minimum_metadata = {}
774
+ for key in train_util.SS_METADATA_MINIMUM_KEYS:
775
+ if key in metadata:
776
+ minimum_metadata[key] = metadata[key]
777
+
778
+ # calculate steps to skip when resuming or starting from a specific step
779
+ initial_step = 0
780
+ if args.initial_epoch is not None or args.initial_step is not None:
781
+ # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
782
+ if steps_from_state is not None:
783
+ logger.warning(
784
+ "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
785
+ )
786
+ if args.initial_step is not None:
787
+ initial_step = args.initial_step
788
+ else:
789
+ # num steps per epoch is calculated by num_processes and gradient_accumulation_steps
790
+ initial_step = (args.initial_epoch - 1) * math.ceil(
791
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
792
+ )
793
+ else:
794
+ # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
795
+ if steps_from_state is not None:
796
+ initial_step = steps_from_state
797
+ steps_from_state = None
798
+
799
+ if initial_step > 0:
800
+ assert (
801
+ args.max_train_steps > initial_step
802
+ ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
803
+
804
+ progress_bar = tqdm(
805
+ range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
806
+ )
807
+
808
+ epoch_to_start = 0
809
+ if initial_step > 0:
810
+ if args.skip_until_initial_step:
811
+ # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
812
+ if not args.resume:
813
+ logger.info(
814
+ f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
815
+ )
816
+ logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
817
+ initial_step *= args.gradient_accumulation_steps
818
+
819
+ # set epoch to start to make initial_step less than len(train_dataloader)
820
+ epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
821
+ else:
822
+ # if not, only epoch no is skipped for informative purpose
823
+ epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
824
+ initial_step = 0 # do not skip
825
+
826
+ global_step = 0
827
+
828
+ noise_scheduler = DDPMScheduler(
829
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
830
+ )
831
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
832
+ if args.zero_terminal_snr:
833
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
834
+
835
+ if accelerator.is_main_process:
836
+ init_kwargs = {}
837
+ if args.wandb_run_name:
838
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
839
+ if args.log_tracker_config is not None:
840
+ init_kwargs = toml.load(args.log_tracker_config)
841
+ accelerator.init_trackers(
842
+ "network_train" if args.log_tracker_name is None else args.log_tracker_name,
843
+ config=train_util.get_sanitized_config_or_none(args),
844
+ init_kwargs=init_kwargs,
845
+ )
846
+
847
+ loss_recorder = train_util.LossRecorder()
848
+ del train_dataset_group
849
+
850
+ # callback for step start
851
+ if hasattr(accelerator.unwrap_model(network), "on_step_start"):
852
+ on_step_start = accelerator.unwrap_model(network).on_step_start
853
+ else:
854
+ on_step_start = lambda *args, **kwargs: None
855
+
856
+ # function for saving/removing
857
+ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
858
+ os.makedirs(args.output_dir, exist_ok=True)
859
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
860
+
861
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
862
+ metadata["ss_training_finished_at"] = str(time.time())
863
+ metadata["ss_steps"] = str(steps)
864
+ metadata["ss_epoch"] = str(epoch_no)
865
+
866
+ metadata_to_save = minimum_metadata if args.no_metadata else metadata
867
+ sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
868
+ metadata_to_save.update(sai_metadata)
869
+
870
+ unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
871
+ if args.huggingface_repo_id is not None:
872
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
873
+
874
+ def remove_model(old_ckpt_name):
875
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
876
+ if os.path.exists(old_ckpt_file):
877
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
878
+ os.remove(old_ckpt_file)
879
+
880
+ # For --sample_at_first
881
+ self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
882
+
883
+ # training loop
884
+ if initial_step > 0: # only if skip_until_initial_step is specified
885
+ for skip_epoch in range(epoch_to_start): # skip epochs
886
+ logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
887
+ initial_step -= len(train_dataloader)
888
+ global_step = initial_step
889
+
890
+ for epoch in range(epoch_to_start, num_train_epochs):
891
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
892
+ current_epoch.value = epoch + 1
893
+
894
+ metadata["ss_epoch"] = str(epoch + 1)
895
+
896
+ accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
897
+
898
+ skipped_dataloader = None
899
+ if initial_step > 0:
900
+ skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
901
+ initial_step = 1
902
+
903
+ for step, batch in enumerate(skipped_dataloader or train_dataloader):
904
+ current_step.value = global_step
905
+ if initial_step > 0:
906
+ initial_step -= 1
907
+ continue
908
+
909
+ with accelerator.accumulate(training_model):
910
+ on_step_start(text_encoder, unet)
911
+
912
+ if "latents" in batch and batch["latents"] is not None:
913
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
914
+ else:
915
+ if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
916
+ with torch.no_grad():
917
+ # encode to latents
918
+ latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
919
+ else:
920
+ chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)]
921
+ list_latents = []
922
+ for chunk in chunks:
923
+ with torch.no_grad():
924
+ # encode to latents
925
+ list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype))
926
+ latents = torch.cat(list_latents, dim=0)
927
+ # if NaNs are found, warn and replace them with zeros
928
+ if torch.any(torch.isnan(latents)):
929
+ accelerator.print("NaN found in latents, replacing with zeros")
930
+ latents = torch.nan_to_num(latents, 0, out=latents)
931
+ latents = latents * self.vae_scale_factor
932
+
933
+ # get multiplier for each sample
934
+ if network_has_multiplier:
935
+ multipliers = batch["network_multipliers"]
936
+ # if all multipliers are same, use single multiplier
937
+ if torch.all(multipliers == multipliers[0]):
938
+ multipliers = multipliers[0].item()
939
+ else:
940
+ raise NotImplementedError("multipliers for each sample is not supported yet")
941
+ # print(f"set multiplier: {multipliers}")
942
+ accelerator.unwrap_model(network).set_multiplier(multipliers)
943
+
944
+ with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
945
+ # Get the text embedding for conditioning
946
+ if args.weighted_captions:
947
+ text_encoder_conds = get_weighted_text_embeddings(
948
+ tokenizer,
949
+ text_encoder,
950
+ batch["captions"],
951
+ accelerator.device,
952
+ args.max_token_length // 75 if args.max_token_length else 1,
953
+ clip_skip=args.clip_skip,
954
+ )
955
+ else:
956
+ text_encoder_conds = self.get_text_cond(
957
+ args, accelerator, batch, tokenizers, text_encoders, weight_dtype
958
+ )
959
+
960
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
961
+ # with noise offset and/or multires noise if specified
962
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
963
+ args, noise_scheduler, latents
964
+ )
965
+
966
+ # ensure the hidden state will require grad
967
+ if args.gradient_checkpointing:
968
+ for x in noisy_latents:
969
+ x.requires_grad_(True)
970
+ for t in text_encoder_conds:
971
+ t.requires_grad_(True)
972
+
973
+ # Predict the noise residual
974
+ with accelerator.autocast():
975
+ noise_pred = self.call_unet(
976
+ args,
977
+ accelerator,
978
+ unet,
979
+ noisy_latents.requires_grad_(train_unet),
980
+ timesteps,
981
+ text_encoder_conds,
982
+ batch,
983
+ weight_dtype,
984
+ )
985
+
986
+ if args.v_parameterization:
987
+ # v-parameterization training
988
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
989
+ else:
990
+ target = noise
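+ # note: for v-parameterization the target is the velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0,
+ # while the default epsilon-prediction target is simply the sampled noise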
991
+
992
+ loss = train_util.conditional_loss(
993
+ noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
994
+ )
995
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
996
+ loss = apply_masked_loss(loss, batch)
997
+ loss = loss.mean([1, 2, 3])
998
+
999
+ loss_weights = batch["loss_weights"] # per-sample weights
1000
+ loss = loss * loss_weights
1001
+
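+ # optional timestep-dependent loss re-weightings below (min-SNR, v-pred scaling/likeness, debiased estimation);
+ # each is applied only when its corresponding argument is set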
1002
+ if args.min_snr_gamma:
1003
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
1004
+ if args.scale_v_pred_loss_like_noise_pred:
1005
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
1006
+ if args.v_pred_like_loss:
1007
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
1008
+ if args.debiased_estimation_loss:
1009
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
1010
+
1011
+ loss = loss.mean() # mean over the batch, so no need to divide by batch_size
1012
+
1013
+ accelerator.backward(loss)
1014
+ if accelerator.sync_gradients:
1015
+ self.all_reduce_network(accelerator, network) # sync DDP grad manually
1016
+ if args.max_grad_norm != 0.0:
1017
+ params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
1018
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1019
+
1020
+ optimizer.step()
1021
+ lr_scheduler.step()
1022
+ optimizer.zero_grad(set_to_none=True)
1023
+
1024
+ if args.scale_weight_norms:
1025
+ keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
1026
+ args.scale_weight_norms, accelerator.device
1027
+ )
1028
+ max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
1029
+ else:
1030
+ keys_scaled, mean_norm, maximum_norm = None, None, None
1031
+
1032
+ # Checks if the accelerator has performed an optimization step behind the scenes
1033
+ if accelerator.sync_gradients:
1034
+ progress_bar.update(1)
1035
+ global_step += 1
1036
+
1037
+ self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
1038
+
1039
+ # save the model at every specified number of steps
1040
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
1041
+ accelerator.wait_for_everyone()
1042
+ if accelerator.is_main_process:
1043
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
1044
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
1045
+
1046
+ if args.save_state:
1047
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
1048
+
1049
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
1050
+ if remove_step_no is not None:
1051
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
1052
+ remove_model(remove_ckpt_name)
1053
+
1054
+ current_loss = loss.detach().item()
1055
+ loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
1056
+ avr_loss: float = loss_recorder.moving_average
1057
+ logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
1058
+ progress_bar.set_postfix(**logs)
1059
+
1060
+ if args.scale_weight_norms:
1061
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})
1062
+
1063
+ if args.logging_dir is not None:
1064
+ logs = self.generate_step_logs(
1065
+ args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
1066
+ )
1067
+ accelerator.log(logs, step=global_step)
1068
+
1069
+ if global_step >= args.max_train_steps:
1070
+ break
1071
+
1072
+ if args.logging_dir is not None:
1073
+ logs = {"loss/epoch": loss_recorder.moving_average}
1074
+ accelerator.log(logs, step=epoch + 1)
1075
+
1076
+ accelerator.wait_for_everyone()
1077
+
1078
+ # save the model at every specified number of epochs
1079
+ if args.save_every_n_epochs is not None:
1080
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
1081
+ if is_main_process and saving:
1082
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
1083
+ save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
1084
+
1085
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
1086
+ if remove_epoch_no is not None:
1087
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
1088
+ remove_model(remove_ckpt_name)
1089
+
1090
+ if args.save_state:
1091
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
1092
+
1093
+ self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
1094
+
1095
+ # end of epoch
1096
+
1097
+ # metadata["ss_epoch"] = str(num_train_epochs)
1098
+ metadata["ss_training_finished_at"] = str(time.time())
1099
+
1100
+ if is_main_process:
1101
+ network = accelerator.unwrap_model(network)
1102
+
1103
+ accelerator.end_training()
1104
+
1105
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
1106
+ train_util.save_state_on_train_end(args, accelerator)
1107
+
1108
+ if is_main_process:
1109
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
1110
+ save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
1111
+
1112
+ logger.info("model saved.")
1113
+
1114
+
1115
+ def setup_parser() -> argparse.ArgumentParser:
1116
+ parser = argparse.ArgumentParser()
1117
+
1118
+ add_logging_arguments(parser)
1119
+ train_util.add_sd_models_arguments(parser)
1120
+ train_util.add_dataset_arguments(parser, True, True, True)
1121
+ train_util.add_training_arguments(parser, True)
1122
+ train_util.add_masked_loss_arguments(parser)
1123
+ deepspeed_utils.add_deepspeed_arguments(parser)
1124
+ train_util.add_optimizer_arguments(parser)
1125
+ config_util.add_config_arguments(parser)
1126
+ custom_train_functions.add_custom_train_arguments(parser)
1127
+
1128
+ parser.add_argument(
1129
+ "--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
1130
+ )
1131
+ parser.add_argument(
1132
+ "--save_model_as",
1133
+ type=str,
1134
+ default="safetensors",
1135
+ choices=[None, "ckpt", "pt", "safetensors"],
1136
+ help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
1137
+ )
1138
+
1139
+ parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
1140
+ parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
1141
+
1142
+ parser.add_argument(
1143
+ "--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
1144
+ )
1145
+ parser.add_argument(
1146
+ "--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
1147
+ )
1148
+ parser.add_argument(
1149
+ "--network_dim",
1150
+ type=int,
1151
+ default=None,
1152
+ help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
1153
+ )
1154
+ parser.add_argument(
1155
+ "--network_alpha",
1156
+ type=float,
1157
+ default=1,
1158
+ help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
1159
+ )
1160
+ parser.add_argument(
1161
+ "--network_dropout",
1162
+ type=float,
1163
+ default=None,
1164
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
1165
+ )
1166
+ parser.add_argument(
1167
+ "--network_args",
1168
+ type=str,
1169
+ default=None,
1170
+ nargs="*",
1171
+ help="additional arguments for network (key=value) / ネットワークへの追加の引数",
1172
+ )
1173
+ parser.add_argument(
1174
+ "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
1175
+ )
1176
+ parser.add_argument(
1177
+ "--network_train_text_encoder_only",
1178
+ action="store_true",
1179
+ help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
1180
+ )
1181
+ parser.add_argument(
1182
+ "--training_comment",
1183
+ type=str,
1184
+ default=None,
1185
+ help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
1186
+ )
1187
+ parser.add_argument(
1188
+ "--dim_from_weights",
1189
+ action="store_true",
1190
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
1191
+ )
1192
+ parser.add_argument(
1193
+ "--scale_weight_norms",
1194
+ type=float,
1195
+ default=None,
1196
+ help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
1197
+ )
1198
+ parser.add_argument(
1199
+ "--base_weights",
1200
+ type=str,
1201
+ default=None,
1202
+ nargs="*",
1203
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
1204
+ )
1205
+ parser.add_argument(
1206
+ "--base_weights_multiplier",
1207
+ type=float,
1208
+ default=None,
1209
+ nargs="*",
1210
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
1211
+ )
1212
+ parser.add_argument(
1213
+ "--no_half_vae",
1214
+ action="store_true",
1215
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
1216
+ )
1217
+ parser.add_argument(
1218
+ "--skip_until_initial_step",
1219
+ action="store_true",
1220
+ help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
1221
+ )
1222
+ parser.add_argument(
1223
+ "--initial_epoch",
1224
+ type=int,
1225
+ default=None,
1226
+ help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
1227
+ + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
1228
+ )
1229
+ parser.add_argument(
1230
+ "--initial_step",
1231
+ type=int,
1232
+ default=None,
1233
+ help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
1234
+ + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
1235
+ )
1236
+ # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
1237
+ # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
1238
+ # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
1239
+ return parser
1240
+
1241
+
1242
+ if __name__ == "__main__":
1243
+ parser = setup_parser()
1244
+
1245
+ args = parser.parse_args()
1246
+ train_util.verify_command_line_training_args(args)
1247
+ args = train_util.read_config_from_file(args, parser)
1248
+
1249
+ trainer = NetworkTrainer()
1250
+ trainer.train(args)
train_textual_inversion.py ADDED
@@ -0,0 +1,813 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ from multiprocessing import Value
5
+ import toml
6
+
7
+ from tqdm import tqdm
8
+
9
+ import torch
10
+ from library.device_utils import init_ipex, clean_memory_on_device
11
+
12
+
13
+ init_ipex()
14
+
15
+ from accelerate.utils import set_seed
16
+ from diffusers import DDPMScheduler
17
+ from transformers import CLIPTokenizer
18
+ from library import deepspeed_utils, model_util
19
+
20
+ import library.train_util as train_util
21
+ import library.huggingface_util as huggingface_util
22
+ import library.config_util as config_util
23
+ from library.config_util import (
24
+ ConfigSanitizer,
25
+ BlueprintGenerator,
26
+ )
27
+ import library.custom_train_functions as custom_train_functions
28
+ from library.custom_train_functions import (
29
+ apply_snr_weight,
30
+ prepare_scheduler_for_custom_training,
31
+ scale_v_prediction_loss_like_noise_prediction,
32
+ add_v_prediction_like_loss,
33
+ apply_debiased_estimation,
34
+ apply_masked_loss,
35
+ )
36
+ from library.utils import setup_logging, add_logging_arguments
37
+
38
+ setup_logging()
39
+ import logging
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+ imagenet_templates_small = [
44
+ "a photo of a {}",
45
+ "a rendering of a {}",
46
+ "a cropped photo of the {}",
47
+ "the photo of a {}",
48
+ "a photo of a clean {}",
49
+ "a photo of a dirty {}",
50
+ "a dark photo of the {}",
51
+ "a photo of my {}",
52
+ "a photo of the cool {}",
53
+ "a close-up photo of a {}",
54
+ "a bright photo of the {}",
55
+ "a cropped photo of a {}",
56
+ "a photo of the {}",
57
+ "a good photo of the {}",
58
+ "a photo of one {}",
59
+ "a close-up photo of the {}",
60
+ "a rendition of the {}",
61
+ "a photo of the clean {}",
62
+ "a rendition of a {}",
63
+ "a photo of a nice {}",
64
+ "a good photo of a {}",
65
+ "a photo of the nice {}",
66
+ "a photo of the small {}",
67
+ "a photo of the weird {}",
68
+ "a photo of the large {}",
69
+ "a photo of a cool {}",
70
+ "a photo of a small {}",
71
+ ]
72
+
73
+ imagenet_style_templates_small = [
74
+ "a painting in the style of {}",
75
+ "a rendering in the style of {}",
76
+ "a cropped painting in the style of {}",
77
+ "the painting in the style of {}",
78
+ "a clean painting in the style of {}",
79
+ "a dirty painting in the style of {}",
80
+ "a dark painting in the style of {}",
81
+ "a picture in the style of {}",
82
+ "a cool painting in the style of {}",
83
+ "a close-up painting in the style of {}",
84
+ "a bright painting in the style of {}",
85
+ "a cropped painting in the style of {}",
86
+ "a good painting in the style of {}",
87
+ "a close-up painting in the style of {}",
88
+ "a rendition in the style of {}",
89
+ "a nice painting in the style of {}",
90
+ "a small painting in the style of {}",
91
+ "a weird painting in the style of {}",
92
+ "a large painting in the style of {}",
93
+ ]
94
+
95
+
96
+ class TextualInversionTrainer:
97
+ def __init__(self):
98
+ self.vae_scale_factor = 0.18215
99
+ self.is_sdxl = False
100
+
101
+ def assert_extra_args(self, args, train_dataset_group):
102
+ train_dataset_group.verify_bucket_reso_steps(64)
103
+
104
+ def load_target_model(self, args, weight_dtype, accelerator):
105
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
106
+ return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
107
+
108
+ def load_tokenizer(self, args):
109
+ tokenizer = train_util.load_tokenizer(args)
110
+ return tokenizer
111
+
112
+ def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
113
+ pass
114
+
115
+ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
116
+ with torch.enable_grad():
117
+ input_ids = batch["input_ids"].to(accelerator.device)
118
+ encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
119
+ return encoder_hidden_states
120
+
121
+ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
122
+ noise_pred = unet(noisy_latents, timesteps, text_conds).sample
123
+ return noise_pred
124
+
125
+ def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
126
+ train_util.sample_images(
127
+ accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
128
+ )
129
+
130
+ def save_weights(self, file, updated_embs, save_dtype, metadata):
131
+ state_dict = {"emb_params": updated_embs[0]}
132
+
133
+ if save_dtype is not None:
134
+ for key in list(state_dict.keys()):
135
+ v = state_dict[key]
136
+ v = v.detach().clone().to("cpu").to(save_dtype)
137
+ state_dict[key] = v
138
+
139
+ if os.path.splitext(file)[1] == ".safetensors":
140
+ from safetensors.torch import save_file
141
+
142
+ save_file(state_dict, file, metadata)
143
+ else:
144
+ torch.save(state_dict, file) # can be loaded in Web UI
145
+
146
+ def load_weights(self, file):
147
+ if os.path.splitext(file)[1] == ".safetensors":
148
+ from safetensors.torch import load_file
149
+
150
+ data = load_file(file)
151
+ else:
152
+ # compatible to Web UI's file format
153
+ data = torch.load(file, map_location="cpu")
154
+ if type(data) != dict:
155
+ raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
156
+
157
+ if "string_to_param" in data: # textual inversion embeddings
158
+ data = data["string_to_param"]
159
+ if hasattr(data, "_parameters"): # support old PyTorch?
160
+ data = getattr(data, "_parameters")
161
+
162
+ emb = next(iter(data.values()))
163
+ if type(emb) != torch.Tensor:
164
+ raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
165
+
166
+ if len(emb.size()) == 1:
167
+ emb = emb.unsqueeze(0)
168
+
169
+ return [emb]
170
+
171
+ def train(self, args):
172
+ if args.output_name is None:
173
+ args.output_name = args.token_string
174
+ use_template = args.use_object_template or args.use_style_template
175
+
176
+ train_util.verify_training_args(args)
177
+ train_util.prepare_dataset_args(args, True)
178
+ setup_logging(args, reset=True)
179
+
180
+ cache_latents = args.cache_latents
181
+
182
+ if args.seed is not None:
183
+ set_seed(args.seed)
184
+
185
+ tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
186
+ tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
187
+
188
+ # prepare accelerator
189
+ logger.info("prepare accelerator")
190
+ accelerator = train_util.prepare_accelerator(args)
191
+
192
+ # prepare a dtype that matches mixed precision and cast as needed
193
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
194
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
195
+
196
+ # load the model
197
+ model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
198
+ text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
199
+
200
+ if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
201
+ accelerator.print(
202
+ "accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
203
+ + "accelerateでは複数のモデル(テキストエンコーダー)のgradient_accumulation_stepsはサポートされていないようです"
204
+ )
205
+
206
+ # Convert the init_word to token_id
207
+ init_token_ids_list = []
208
+ if args.init_word is not None:
209
+ for i, tokenizer in enumerate(tokenizers):
210
+ init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
211
+ if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
212
+ accelerator.print(
213
+ f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / "
214
+ + f"初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: tokenizer {i+1}, length {len(init_token_ids)}"
215
+ )
216
+ init_token_ids_list.append(init_token_ids)
217
+ else:
218
+ init_token_ids_list = [None] * len(tokenizers)
219
+
220
+ # tokenizerに新しい単語を追加する。追加する単語の数はnum_vectors_per_token
221
+ # token_stringが hoge の場合、"hoge", "hoge1", "hoge2", ... が追加される
222
+ # add new word to tokenizer, count is num_vectors_per_token
223
+ # if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
224
+
225
+ self.assert_token_string(args.token_string, tokenizers)
226
+
227
+ token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
228
+ token_ids_list = []
229
+ token_embeds_list = []
230
+ for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
231
+ num_added_tokens = tokenizer.add_tokens(token_strings)
232
+ assert (
233
+ num_added_tokens == args.num_vectors_per_token
234
+ ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: tokenizer {i+1}, {args.token_string}"
235
+
236
+ token_ids = tokenizer.convert_tokens_to_ids(token_strings)
237
+ accelerator.print(f"tokens are added for tokenizer {i+1}: {token_ids}")
238
+ assert (
239
+ min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1
240
+ ), f"token ids is not ordered : tokenizer {i+1}, {token_ids}"
241
+ assert (
242
+ len(tokenizer) - 1 == token_ids[-1]
243
+ ), f"token ids is not end of tokenize: tokenizer {i+1}, {token_ids}, {len(tokenizer)}"
244
+ token_ids_list.append(token_ids)
245
+
246
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
247
+ text_encoder.resize_token_embeddings(len(tokenizer))
248
+
249
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
250
+ token_embeds = text_encoder.get_input_embeddings().weight.data
251
+ if init_token_ids is not None:
252
+ for i, token_id in enumerate(token_ids):
253
+ token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
254
+ # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
255
+ token_embeds_list.append(token_embeds)
256
+
257
+ # load weights
258
+ if args.weights is not None:
259
+ embeddings_list = self.load_weights(args.weights)
260
+ assert len(token_ids) == len(
261
+ embeddings_list[0]
262
+ ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
263
+ # accelerator.print(token_ids, embeddings.size())
264
+ for token_ids, embeddings, token_embeds in zip(token_ids_list, embeddings_list, token_embeds_list):
265
+ for token_id, embedding in zip(token_ids, embeddings):
266
+ token_embeds[token_id] = embedding
267
+ # accelerator.print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
268
+ accelerator.print(f"weighs loaded")
269
+
270
+ accelerator.print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
271
+
272
+ # prepare the dataset
273
+ if args.dataset_class is None:
274
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
275
+ if args.dataset_config is not None:
276
+ accelerator.print(f"Load dataset config from {args.dataset_config}")
277
+ user_config = config_util.load_user_config(args.dataset_config)
278
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
279
+ if any(getattr(args, attr) is not None for attr in ignored):
280
+ accelerator.print(
281
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
282
+ ", ".join(ignored)
283
+ )
284
+ )
285
+ else:
286
+ use_dreambooth_method = args.in_json is None
287
+ if use_dreambooth_method:
288
+ accelerator.print("Use DreamBooth method.")
289
+ user_config = {
290
+ "datasets": [
291
+ {
292
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
293
+ args.train_data_dir, args.reg_data_dir
294
+ )
295
+ }
296
+ ]
297
+ }
298
+ else:
299
+ logger.info("Train with captions.")
300
+ user_config = {
301
+ "datasets": [
302
+ {
303
+ "subsets": [
304
+ {
305
+ "image_dir": args.train_data_dir,
306
+ "metadata_file": args.in_json,
307
+ }
308
+ ]
309
+ }
310
+ ]
311
+ }
312
+
313
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
314
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
315
+ else:
316
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
317
+
318
+ self.assert_extra_args(args, train_dataset_group)
319
+
320
+ current_epoch = Value("i", 0)
321
+ current_step = Value("i", 0)
322
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
323
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
324
+
325
+ # make captions: very rough implementation that rewrites captions to "tokenstring tokenstring1 tokenstring2 ... tokenstringN"
326
+ if use_template:
327
+ accelerator.print(f"use template for training captions. is object: {args.use_object_template}")
328
+ templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
329
+ replace_to = " ".join(token_strings)
330
+ captions = []
331
+ for tmpl in templates:
332
+ captions.append(tmpl.format(replace_to))
333
+ train_dataset_group.add_replacement("", captions)
334
+
335
+ # for sample image generation
336
+ if args.num_vectors_per_token > 1:
337
+ prompt_replacement = (args.token_string, replace_to)
338
+ else:
339
+ prompt_replacement = None
340
+ else:
341
+ # for sample image generation
342
+ if args.num_vectors_per_token > 1:
343
+ replace_to = " ".join(token_strings)
344
+ train_dataset_group.add_replacement(args.token_string, replace_to)
345
+ prompt_replacement = (args.token_string, replace_to)
346
+ else:
347
+ prompt_replacement = None
348
+
349
+ if args.debug_dataset:
350
+ train_util.debug_dataset(train_dataset_group, show_input_ids=True)
351
+ return
352
+ if len(train_dataset_group) == 0:
353
+ accelerator.print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
354
+ return
355
+
356
+ if cache_latents:
357
+ assert (
358
+ train_dataset_group.is_latent_cacheable()
359
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
360
+
361
+ # apply xformers or memory efficient attention to the model
362
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
363
+ if torch.__version__ >= "2.0.0": # this is usable with xformers builds that support PyTorch 2.0.0 or later
364
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
365
+
366
+ # prepare for training
367
+ if cache_latents:
368
+ vae.to(accelerator.device, dtype=vae_dtype)
369
+ vae.requires_grad_(False)
370
+ vae.eval()
371
+ with torch.no_grad():
372
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
373
+ vae.to("cpu")
374
+ clean_memory_on_device(accelerator.device)
375
+
376
+ accelerator.wait_for_everyone()
377
+
378
+ if args.gradient_checkpointing:
379
+ unet.enable_gradient_checkpointing()
380
+ for text_encoder in text_encoders:
381
+ text_encoder.gradient_checkpointing_enable()
382
+
383
+ # prepare classes needed for training
384
+ accelerator.print("prepare optimizer, data loader etc.")
385
+ trainable_params = []
386
+ for text_encoder in text_encoders:
387
+ trainable_params += text_encoder.get_input_embeddings().parameters()
388
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
389
+
390
+ # prepare the dataloader
391
+ # note: persistent_workers cannot be used when the number of DataLoader workers is 0
392
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
393
+ train_dataloader = torch.utils.data.DataLoader(
394
+ train_dataset_group,
395
+ batch_size=1,
396
+ shuffle=True,
397
+ collate_fn=collator,
398
+ num_workers=n_workers,
399
+ persistent_workers=args.persistent_data_loader_workers,
400
+ )
401
+
402
+ # calculate the number of training steps
403
+ if args.max_train_epochs is not None:
404
+ args.max_train_steps = args.max_train_epochs * math.ceil(
405
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
406
+ )
407
+ accelerator.print(
408
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
409
+ )
410
+
411
+ # also pass the number of training steps to the dataset
412
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
413
+
414
+ # prepare the lr scheduler
415
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
416
+
417
+ # accelerator takes care of wrapping these appropriately
418
+ if len(text_encoders) == 1:
419
+ text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
420
+ text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
421
+ )
422
+
423
+ elif len(text_encoders) == 2:
424
+ text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
425
+ text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
426
+ )
427
+
428
+ text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
429
+
430
+ else:
431
+ raise NotImplementedError()
432
+
433
+ index_no_updates_list = []
434
+ orig_embeds_params_list = []
435
+ for tokenizer, token_ids, text_encoder in zip(tokenizers, token_ids_list, text_encoders):
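+ # mask over the vocabulary that is True for every pre-existing token; the newly added
+ # placeholder tokens sit at the end of the vocabulary, as asserted above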
436
+ index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
437
+ index_no_updates_list.append(index_no_updates)
438
+
439
+ # accelerator.print(len(index_no_updates), torch.sum(index_no_updates))
440
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
441
+ orig_embeds_params_list.append(orig_embeds_params)
442
+
443
+ # Freeze all parameters except for the token embeddings in text encoder
444
+ text_encoder.requires_grad_(True)
445
+ unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
446
+ unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
447
+ unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
448
+ unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
449
+ # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
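+ # net effect: only the token embedding matrix receives gradients, and the rows of
+ # pre-existing tokens are restored from orig_embeds_params after every optimizer step below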
450
+
451
+ unet.requires_grad_(False)
452
+ unet.to(accelerator.device, dtype=weight_dtype)
453
+ if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
454
+ # TODO the U-Net was replaced with the original implementation, so this should no longer be needed; verify and remove later
455
+ unet.train()
456
+ else:
457
+ unet.eval()
458
+
459
+ if not cache_latents: # the VAE is needed when latents are not cached, so prepare it
460
+ vae.requires_grad_(False)
461
+ vae.eval()
462
+ vae.to(accelerator.device, dtype=vae_dtype)
463
+
464
+ # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
465
+ if args.full_fp16:
466
+ train_util.patch_accelerator_for_fp16_training(accelerator)
467
+ for text_encoder in text_encoders:
468
+ text_encoder.to(weight_dtype)
469
+ if args.full_bf16:
470
+ for text_encoder in text_encoders:
471
+ text_encoder.to(weight_dtype)
472
+
473
+ # resume from saved state if specified
474
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
475
+
476
+ # calculate the number of epochs
477
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
478
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
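+ # e.g. with max_train_steps=2000 and 500 update steps per epoch this gives ceil(2000 / 500) = 4 epochs (illustrative numbers)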
479
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
480
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
481
+
482
+ # start training
483
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
484
+ accelerator.print("running training / 学習開始")
485
+ accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
486
+ accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
487
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
488
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
489
+ accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
490
+ accelerator.print(
491
+ f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
492
+ )
493
+ accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
494
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
495
+
496
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
497
+ global_step = 0
498
+
499
+ noise_scheduler = DDPMScheduler(
500
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
501
+ )
502
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
503
+ if args.zero_terminal_snr:
504
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
505
+
506
+ if accelerator.is_main_process:
507
+ init_kwargs = {}
508
+ if args.wandb_run_name:
509
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
510
+ if args.log_tracker_config is not None:
511
+ init_kwargs = toml.load(args.log_tracker_config)
512
+ accelerator.init_trackers(
513
+ "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
514
+ )
515
+
516
+ # function for saving/removing
517
+ def save_model(ckpt_name, embs_list, steps, epoch_no, force_sync_upload=False):
518
+ os.makedirs(args.output_dir, exist_ok=True)
519
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
520
+
521
+ accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
522
+
523
+ sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, False, True)
524
+
525
+ self.save_weights(ckpt_file, embs_list, save_dtype, sai_metadata)
526
+ if args.huggingface_repo_id is not None:
527
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
528
+
529
+ def remove_model(old_ckpt_name):
530
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
531
+ if os.path.exists(old_ckpt_file):
532
+ accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
533
+ os.remove(old_ckpt_file)
534
+
535
+ # For --sample_at_first
536
+ self.sample_images(
537
+ accelerator,
538
+ args,
539
+ 0,
540
+ global_step,
541
+ accelerator.device,
542
+ vae,
543
+ tokenizer_or_list,
544
+ text_encoder_or_list,
545
+ unet,
546
+ prompt_replacement,
547
+ )
548
+
549
+ # training loop
550
+ for epoch in range(num_train_epochs):
551
+ accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
552
+ current_epoch.value = epoch + 1
553
+
554
+ for text_encoder in text_encoders:
555
+ text_encoder.train()
556
+
557
+ loss_total = 0
558
+
559
+ for step, batch in enumerate(train_dataloader):
560
+ current_step.value = global_step
561
+ with accelerator.accumulate(text_encoders[0]):
562
+ with torch.no_grad():
563
+ if "latents" in batch and batch["latents"] is not None:
564
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
565
+ else:
566
+ # encode to latents
567
+ latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
568
+ latents = latents * self.vae_scale_factor
569
+
570
+ # Get the text embedding for conditioning
571
+ text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
572
+
573
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
574
+ # with noise offset and/or multires noise if specified
575
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
576
+ args, noise_scheduler, latents
577
+ )
578
+
579
+ # Predict the noise residual
580
+ with accelerator.autocast():
581
+ noise_pred = self.call_unet(
582
+ args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
583
+ )
584
+
585
+ if args.v_parameterization:
586
+ # v-parameterization training
587
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
588
+ else:
589
+ target = noise
590
+
591
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
592
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
593
+ loss = apply_masked_loss(loss, batch)
594
+ loss = loss.mean([1, 2, 3])
595
+
596
+ loss_weights = batch["loss_weights"] # per-sample weights
597
+ loss = loss * loss_weights
598
+
599
+ if args.min_snr_gamma:
600
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
601
+ if args.scale_v_pred_loss_like_noise_pred:
602
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
603
+ if args.v_pred_like_loss:
604
+ loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
605
+ if args.debiased_estimation_loss:
606
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
607
+
608
+ loss = loss.mean() # mean over the batch, so no need to divide by batch_size
609
+
610
+ accelerator.backward(loss)
611
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
612
+ params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
613
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
614
+
615
+ optimizer.step()
616
+ lr_scheduler.step()
617
+ optimizer.zero_grad(set_to_none=True)
618
+
619
+ # Let's make sure we don't update any embedding weights besides the newly added token
620
+ with torch.no_grad():
621
+ for text_encoder, orig_embeds_params, index_no_updates in zip(
622
+ text_encoders, orig_embeds_params_list, index_no_updates_list
623
+ ):
624
+ # if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
625
+ input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
626
+ input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
627
+ index_no_updates
628
+ ]
629
+
630
+ # Checks if the accelerator has performed an optimization step behind the scenes
631
+ if accelerator.sync_gradients:
632
+ progress_bar.update(1)
633
+ global_step += 1
634
+
635
+ self.sample_images(
636
+ accelerator,
637
+ args,
638
+ None,
639
+ global_step,
640
+ accelerator.device,
641
+ vae,
642
+ tokenizer_or_list,
643
+ text_encoder_or_list,
644
+ unet,
645
+ prompt_replacement,
646
+ )
647
+
648
+ # save the model at every specified number of steps
649
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
650
+ accelerator.wait_for_everyone()
651
+ if accelerator.is_main_process:
652
+ updated_embs_list = []
653
+ for text_encoder, token_ids in zip(text_encoders, token_ids_list):
654
+ updated_embs = (
655
+ accelerator.unwrap_model(text_encoder)
656
+ .get_input_embeddings()
657
+ .weight[token_ids]
658
+ .data.detach()
659
+ .clone()
660
+ )
661
+ updated_embs_list.append(updated_embs)
662
+
663
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
664
+ save_model(ckpt_name, updated_embs_list, global_step, epoch)
665
+
666
+ if args.save_state:
667
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
668
+
669
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
670
+ if remove_step_no is not None:
671
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
672
+ remove_model(remove_ckpt_name)
673
+
674
+ current_loss = loss.detach().item()
675
+ if args.logging_dir is not None:
676
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
677
+ if (
678
+ args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
679
+ ): # tracking d*lr value
680
+ logs["lr/d*lr"] = (
681
+ lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
682
+ )
683
+ accelerator.log(logs, step=global_step)
684
+
685
+ loss_total += current_loss
686
+ avr_loss = loss_total / (step + 1)
687
+ logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
688
+ progress_bar.set_postfix(**logs)
689
+
690
+ if global_step >= args.max_train_steps:
691
+ break
692
+
693
+ if args.logging_dir is not None:
694
+ logs = {"loss/epoch": loss_total / len(train_dataloader)}
695
+ accelerator.log(logs, step=epoch + 1)
696
+
697
+ accelerator.wait_for_everyone()
698
+
699
+ updated_embs_list = []
700
+ for text_encoder, token_ids in zip(text_encoders, token_ids_list):
701
+ updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
702
+ updated_embs_list.append(updated_embs)
703
+
704
+ if args.save_every_n_epochs is not None:
705
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
706
+ if accelerator.is_main_process and saving:
707
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
708
+ save_model(ckpt_name, updated_embs_list, epoch + 1, global_step)
709
+
710
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
711
+ if remove_epoch_no is not None:
712
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
713
+ remove_model(remove_ckpt_name)
714
+
715
+ if args.save_state:
716
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
717
+
718
+ self.sample_images(
719
+ accelerator,
720
+ args,
721
+ epoch + 1,
722
+ global_step,
723
+ accelerator.device,
724
+ vae,
725
+ tokenizer_or_list,
726
+ text_encoder_or_list,
727
+ unet,
728
+ prompt_replacement,
729
+ )
730
+
731
+ # end of epoch
732
+
733
+ is_main_process = accelerator.is_main_process
734
+ if is_main_process:
735
+ text_encoder = accelerator.unwrap_model(text_encoder)
736
+ updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
737
+
738
+ accelerator.end_training()
739
+
740
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
741
+ train_util.save_state_on_train_end(args, accelerator)
742
+
743
+ if is_main_process:
744
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
745
+ save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
746
+
747
+ logger.info("model saved.")
748
+
749
+
750
+ def setup_parser() -> argparse.ArgumentParser:
751
+ parser = argparse.ArgumentParser()
752
+
753
+ add_logging_arguments(parser)
754
+ train_util.add_sd_models_arguments(parser)
755
+ train_util.add_dataset_arguments(parser, True, True, False)
756
+ train_util.add_training_arguments(parser, True)
757
+ train_util.add_masked_loss_arguments(parser)
758
+ deepspeed_utils.add_deepspeed_arguments(parser)
759
+ train_util.add_optimizer_arguments(parser)
760
+ config_util.add_config_arguments(parser)
761
+ custom_train_functions.add_custom_train_arguments(parser, False)
762
+
763
+ parser.add_argument(
764
+ "--save_model_as",
765
+ type=str,
766
+ default="pt",
767
+ choices=[None, "ckpt", "pt", "safetensors"],
768
+ help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
769
+ )
770
+
771
+ parser.add_argument(
772
+ "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
773
+ )
774
+ parser.add_argument(
775
+ "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
776
+ )
777
+ parser.add_argument(
778
+ "--token_string",
779
+ type=str,
780
+ default=None,
781
+ help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
782
+ )
783
+ parser.add_argument(
784
+ "--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
785
+ )
786
+ parser.add_argument(
787
+ "--use_object_template",
788
+ action="store_true",
789
+ help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
790
+ )
791
+ parser.add_argument(
792
+ "--use_style_template",
793
+ action="store_true",
794
+ help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
795
+ )
796
+ parser.add_argument(
797
+ "--no_half_vae",
798
+ action="store_true",
799
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
800
+ )
801
+
802
+ return parser
803
+
804
+
805
+ if __name__ == "__main__":
806
+ parser = setup_parser()
807
+
808
+ args = parser.parse_args()
809
+ train_util.verify_command_line_training_args(args)
810
+ args = train_util.read_config_from_file(args, parser)
811
+
812
+ trainer = TextualInversionTrainer()
813
+ trainer.train(args)
train_textual_inversion_XTI.py ADDED
@@ -0,0 +1,720 @@
1
+ import importlib
2
+ import argparse
3
+ import math
4
+ import os
5
+ import toml
6
+ from multiprocessing import Value
7
+
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ from library import deepspeed_utils
12
+ from library.device_utils import init_ipex, clean_memory_on_device
13
+
14
+ init_ipex()
15
+
16
+ from accelerate.utils import set_seed
17
+ import diffusers
18
+ from diffusers import DDPMScheduler
19
+ import library
20
+
21
+ import library.train_util as train_util
22
+ import library.huggingface_util as huggingface_util
23
+ import library.config_util as config_util
24
+ from library.config_util import (
25
+ ConfigSanitizer,
26
+ BlueprintGenerator,
27
+ )
28
+ import library.custom_train_functions as custom_train_functions
29
+ from library.custom_train_functions import (
30
+ apply_snr_weight,
31
+ prepare_scheduler_for_custom_training,
32
+ pyramid_noise_like,
33
+ apply_noise_offset,
34
+ scale_v_prediction_loss_like_noise_prediction,
35
+ apply_debiased_estimation,
36
+ apply_masked_loss,
37
+ )
38
+ import library.original_unet as original_unet
39
+ from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
40
+ from library.utils import setup_logging, add_logging_arguments
41
+
42
+ setup_logging()
43
+ import logging
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+ imagenet_templates_small = [
48
+ "a photo of a {}",
49
+ "a rendering of a {}",
50
+ "a cropped photo of the {}",
51
+ "the photo of a {}",
52
+ "a photo of a clean {}",
53
+ "a photo of a dirty {}",
54
+ "a dark photo of the {}",
55
+ "a photo of my {}",
56
+ "a photo of the cool {}",
57
+ "a close-up photo of a {}",
58
+ "a bright photo of the {}",
59
+ "a cropped photo of a {}",
60
+ "a photo of the {}",
61
+ "a good photo of the {}",
62
+ "a photo of one {}",
63
+ "a close-up photo of the {}",
64
+ "a rendition of the {}",
65
+ "a photo of the clean {}",
66
+ "a rendition of a {}",
67
+ "a photo of a nice {}",
68
+ "a good photo of a {}",
69
+ "a photo of the nice {}",
70
+ "a photo of the small {}",
71
+ "a photo of the weird {}",
72
+ "a photo of the large {}",
73
+ "a photo of a cool {}",
74
+ "a photo of a small {}",
75
+ ]
76
+
77
+ imagenet_style_templates_small = [
78
+ "a painting in the style of {}",
79
+ "a rendering in the style of {}",
80
+ "a cropped painting in the style of {}",
81
+ "the painting in the style of {}",
82
+ "a clean painting in the style of {}",
83
+ "a dirty painting in the style of {}",
84
+ "a dark painting in the style of {}",
85
+ "a picture in the style of {}",
86
+ "a cool painting in the style of {}",
87
+ "a close-up painting in the style of {}",
88
+ "a bright painting in the style of {}",
89
+ "a cropped painting in the style of {}",
90
+ "a good painting in the style of {}",
91
+ "a close-up painting in the style of {}",
92
+ "a rendition in the style of {}",
93
+ "a nice painting in the style of {}",
94
+ "a small painting in the style of {}",
95
+ "a weird painting in the style of {}",
96
+ "a large painting in the style of {}",
97
+ ]
98
+
99
+
100
+ def train(args):
101
+ if args.output_name is None:
102
+ args.output_name = args.token_string
103
+ use_template = args.use_object_template or args.use_style_template
104
+ setup_logging(args, reset=True)
105
+
106
+ train_util.verify_training_args(args)
107
+ train_util.prepare_dataset_args(args, True)
108
+
109
+ if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
110
+ logger.warning(
111
+ "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
112
+ )
113
+ assert (
114
+ args.dataset_class is None
115
+ ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
116
+
117
+ cache_latents = args.cache_latents
118
+
119
+ if args.seed is not None:
120
+ set_seed(args.seed)
121
+
122
+ tokenizer = train_util.load_tokenizer(args)
123
+
124
+ # acceleratorを準備する
125
+ logger.info("prepare accelerator")
126
+ accelerator = train_util.prepare_accelerator(args)
127
+
128
+ # mixed precisionに対応した型を用意しておき適宜castする
129
+ weight_dtype, save_dtype = train_util.prepare_dtype(args)
130
+
131
+ # モデルを読み込む
132
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
133
+
134
+ # Convert the init_word to token_id
135
+ if args.init_word is not None:
136
+ init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
137
+ if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
138
+ logger.warning(
139
+ f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
140
+ )
141
+ else:
142
+ init_token_ids = None
143
+
144
+ # add new word to tokenizer, count is num_vectors_per_token
145
+ token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
146
+ num_added_tokens = tokenizer.add_tokens(token_strings)
147
+ assert (
148
+ num_added_tokens == args.num_vectors_per_token
149
+ ), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
150
+
151
+ token_ids = tokenizer.convert_tokens_to_ids(token_strings)
152
+ logger.info(f"tokens are added: {token_ids}")
153
+ assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, "token ids are not ordered"
154
+ assert len(tokenizer) - 1 == token_ids[-1], f"token ids do not end at the end of the tokenizer vocabulary: {len(tokenizer)}"
155
+
156
+ token_strings_XTI = []
157
+ XTI_layers = [
158
+ "IN01",
159
+ "IN02",
160
+ "IN04",
161
+ "IN05",
162
+ "IN07",
163
+ "IN08",
164
+ "MID",
165
+ "OUT03",
166
+ "OUT04",
167
+ "OUT05",
168
+ "OUT06",
169
+ "OUT07",
170
+ "OUT08",
171
+ "OUT09",
172
+ "OUT10",
173
+ "OUT11",
174
+ ]
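+ # XTI (Extended Textual Inversion) keeps a separate embedding per UNet cross-attention block;
+ # every placeholder token is duplicated once for each of the 16 layer names above.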
175
+ for layer_name in XTI_layers:
176
+ token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
177
+
178
+ tokenizer.add_tokens(token_strings_XTI)
179
+ token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
180
+ logger.info(f"tokens are added (XTI): {token_ids_XTI}")
181
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
182
+ text_encoder.resize_token_embeddings(len(tokenizer))
183
+
184
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
185
+ token_embeds = text_encoder.get_input_embeddings().weight.data
186
+ if init_token_ids is not None:
187
+ for i, token_id in enumerate(token_ids_XTI):
188
+ token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
189
+ # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
190
+
191
+ # load weights
192
+ if args.weights is not None:
193
+ embeddings = load_weights(args.weights)
194
+ assert len(token_ids) == len(
195
+ embeddings
196
+ ), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
197
+ # logger.info(token_ids, embeddings.size())
198
+ for token_id, embedding in zip(token_ids_XTI, embeddings):
199
+ token_embeds[token_id] = embedding
200
+ # logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
201
+ logger.info(f"weighs loaded")
202
+
203
+ logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
204
+
205
+ # prepare the dataset
206
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
207
+ if args.dataset_config is not None:
208
+ logger.info(f"Load dataset config from {args.dataset_config}")
209
+ user_config = config_util.load_user_config(args.dataset_config)
210
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
211
+ if any(getattr(args, attr) is not None for attr in ignored):
212
+ logger.info(
213
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
214
+ ", ".join(ignored)
215
+ )
216
+ )
217
+ else:
218
+ use_dreambooth_method = args.in_json is None
219
+ if use_dreambooth_method:
220
+ logger.info("Use DreamBooth method.")
221
+ user_config = {
222
+ "datasets": [
223
+ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
224
+ ]
225
+ }
226
+ else:
227
+ logger.info("Train with captions.")
228
+ user_config = {
229
+ "datasets": [
230
+ {
231
+ "subsets": [
232
+ {
233
+ "image_dir": args.train_data_dir,
234
+ "metadata_file": args.in_json,
235
+ }
236
+ ]
237
+ }
238
+ ]
239
+ }
240
+
241
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
242
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
243
+ train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
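+ # have the dataset produce per-XTI-layer input_ids; they are split back into per-layer hidden states in the training loop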
244
+ current_epoch = Value("i", 0)
245
+ current_step = Value("i", 0)
246
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
247
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
248
+
249
+ # make captions: a very crude implementation that rewrites every caption to "tokenstring tokenstring1 tokenstring2 ... tokenstringN"
250
+ if use_template:
251
+ logger.info(f"use template for training captions. is object: {args.use_object_template}")
252
+ templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
253
+ replace_to = " ".join(token_strings)
254
+ captions = []
255
+ for tmpl in templates:
256
+ captions.append(tmpl.format(replace_to))
257
+ train_dataset_group.add_replacement("", captions)
258
+
259
+ if args.num_vectors_per_token > 1:
260
+ prompt_replacement = (args.token_string, replace_to)
261
+ else:
262
+ prompt_replacement = None
263
+ else:
264
+ if args.num_vectors_per_token > 1:
265
+ replace_to = " ".join(token_strings)
266
+ train_dataset_group.add_replacement(args.token_string, replace_to)
267
+ prompt_replacement = (args.token_string, replace_to)
268
+ else:
269
+ prompt_replacement = None
270
+
271
+ if args.debug_dataset:
272
+ train_util.debug_dataset(train_dataset_group, show_input_ids=True)
273
+ return
274
+ if len(train_dataset_group) == 0:
275
+ logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
276
+ return
277
+
278
+ if cache_latents:
279
+ assert (
280
+ train_dataset_group.is_latent_cacheable()
281
+ ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
282
+
283
+ # incorporate xformers or memory efficient attention into the model
284
+ train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
285
+ original_unet.UNet2DConditionModel.forward = unet_forward_XTI
286
+ original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
287
+ original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
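+ # monkey-patch the UNet forward passes so each cross-attention block reads its own layer-specific text embedding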
288
+
289
+ # prepare for training
290
+ if cache_latents:
291
+ vae.to(accelerator.device, dtype=weight_dtype)
292
+ vae.requires_grad_(False)
293
+ vae.eval()
294
+ with torch.no_grad():
295
+ train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
296
+ vae.to("cpu")
297
+ clean_memory_on_device(accelerator.device)
298
+
299
+ accelerator.wait_for_everyone()
300
+
301
+ if args.gradient_checkpointing:
302
+ unet.enable_gradient_checkpointing()
303
+ text_encoder.gradient_checkpointing_enable()
304
+
305
+ # prepare the classes required for training
306
+ logger.info("prepare optimizer, data loader etc.")
307
+ trainable_params = text_encoder.get_input_embeddings().parameters()
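+ # only the token embedding table of the text encoder is trained; the UNet and VAE stay frozen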
308
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params)
309
+
310
+ # prepare the DataLoader
311
+ # note: persistent_workers cannot be used when the number of DataLoader workers is 0
312
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
313
+ train_dataloader = torch.utils.data.DataLoader(
314
+ train_dataset_group,
315
+ batch_size=1,
316
+ shuffle=True,
317
+ collate_fn=collator,
318
+ num_workers=n_workers,
319
+ persistent_workers=args.persistent_data_loader_workers,
320
+ )
321
+
322
+ # calculate the number of training steps
323
+ if args.max_train_epochs is not None:
324
+ args.max_train_steps = args.max_train_epochs * math.ceil(
325
+ len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
326
+ )
327
+ logger.info(
328
+ f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
329
+ )
330
+
331
+ # also send the number of training steps to the dataset side
332
+ train_dataset_group.set_max_train_steps(args.max_train_steps)
333
+
334
+ # prepare the lr scheduler
335
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
336
+
337
+ # accelerator apparently takes care of the rest for us
338
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
339
+ text_encoder, optimizer, train_dataloader, lr_scheduler
340
+ )
341
+
342
+ index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
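+ # mask of the original (non-XTI) embedding rows; they are restored from orig_embeds_params after every optimizer step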
343
+ # logger.info(len(index_no_updates), torch.sum(index_no_updates))
344
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
345
+
346
+ # Freeze all parameters except for the token embeddings in text encoder
347
+ text_encoder.requires_grad_(True)
348
+ text_encoder.text_model.encoder.requires_grad_(False)
349
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
350
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
351
+ # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
352
+
353
+ unet.requires_grad_(False)
354
+ unet.to(accelerator.device, dtype=weight_dtype)
355
+ if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
356
+ unet.train()
357
+ else:
358
+ unet.eval()
359
+
360
+ if not cache_latents:
361
+ vae.requires_grad_(False)
362
+ vae.eval()
363
+ vae.to(accelerator.device, dtype=weight_dtype)
364
+
365
+ # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
366
+ if args.full_fp16:
367
+ train_util.patch_accelerator_for_fp16_training(accelerator)
368
+ text_encoder.to(weight_dtype)
369
+
370
+ # resume training if specified
371
+ train_util.resume_from_local_or_hf_if_specified(accelerator, args)
372
+
373
+ # calculate the number of epochs
374
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
375
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
376
+ if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
377
+ args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
378
+
379
+ # start training
380
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
381
+ logger.info("running training / 学習開始")
382
+ logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
383
+ logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
384
+ logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
385
+ logger.info(f" num epochs / epoch数: {num_train_epochs}")
386
+ logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
387
+ logger.info(
388
+ f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
389
+ )
390
+ logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
391
+ logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
392
+
393
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
394
+ global_step = 0
395
+
396
+ noise_scheduler = DDPMScheduler(
397
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
398
+ )
399
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
400
+ if args.zero_terminal_snr:
401
+ custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
402
+
403
+ if accelerator.is_main_process:
404
+ init_kwargs = {}
405
+ if args.wandb_run_name:
406
+ init_kwargs["wandb"] = {"name": args.wandb_run_name}
407
+ if args.log_tracker_config is not None:
408
+ init_kwargs = toml.load(args.log_tracker_config)
409
+ accelerator.init_trackers(
410
+ "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
411
+ )
412
+
413
+ # function for saving/removing
414
+ def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
415
+ os.makedirs(args.output_dir, exist_ok=True)
416
+ ckpt_file = os.path.join(args.output_dir, ckpt_name)
417
+
418
+ logger.info("")
419
+ logger.info(f"saving checkpoint: {ckpt_file}")
420
+ save_weights(ckpt_file, embs, save_dtype)
421
+ if args.huggingface_repo_id is not None:
422
+ huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
423
+
424
+ def remove_model(old_ckpt_name):
425
+ old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
426
+ if os.path.exists(old_ckpt_file):
427
+ logger.info(f"removing old checkpoint: {old_ckpt_file}")
428
+ os.remove(old_ckpt_file)
429
+
430
+ # training loop
431
+ for epoch in range(num_train_epochs):
432
+ logger.info("")
433
+ logger.info(f"epoch {epoch+1}/{num_train_epochs}")
434
+ current_epoch.value = epoch + 1
435
+
436
+ text_encoder.train()
437
+
438
+ loss_total = 0
439
+
440
+ for step, batch in enumerate(train_dataloader):
441
+ current_step.value = global_step
442
+ with accelerator.accumulate(text_encoder):
443
+ with torch.no_grad():
444
+ if "latents" in batch and batch["latents"] is not None:
445
+ latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
446
+ else:
447
+ # convert images to latents
448
+ latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
449
+ latents = latents * 0.18215
450
+ b_size = latents.shape[0]
451
+
452
+ # Get the text embedding for conditioning
453
+ input_ids = batch["input_ids"].to(accelerator.device)
454
+ # use float instead of fp16/bf16 (weight_dtype) because the text encoder is float
455
+ encoder_hidden_states = torch.stack(
456
+ [
457
+ train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
458
+ for s in torch.split(input_ids, 1, dim=1)
459
+ ]
460
+ )
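+ # build one set of text-encoder hidden states per slice of input_ids; the patched XTI UNet forward consumes this stack per layer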
461
+
462
+ # Sample noise, sample a random timestep for each image, and add noise to the latents,
463
+ # with noise offset and/or multires noise if specified
464
+ noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
465
+
466
+ # Predict the noise residual
467
+ with accelerator.autocast():
468
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
469
+
470
+ if args.v_parameterization:
471
+ # v-parameterization training
472
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
473
+ else:
474
+ target = noise
475
+
476
+ loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
477
+ if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
478
+ loss = apply_masked_loss(loss, batch)
479
+ loss = loss.mean([1, 2, 3])
480
+
481
+ loss_weights = batch["loss_weights"] # per-sample weights
482
+
483
+ loss = loss * loss_weights
484
+ if args.min_snr_gamma:
485
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
486
+ if args.scale_v_pred_loss_like_noise_pred:
487
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
488
+ if args.debiased_estimation_loss:
489
+ loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
490
+
491
+ loss = loss.mean() # mean over the batch, so no need to divide by batch_size
492
+
493
+ accelerator.backward(loss)
494
+ if accelerator.sync_gradients and args.max_grad_norm != 0.0:
495
+ params_to_clip = text_encoder.get_input_embeddings().parameters()
496
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
497
+
498
+ optimizer.step()
499
+ lr_scheduler.step()
500
+ optimizer.zero_grad(set_to_none=True)
501
+
502
+ # Let's make sure we don't update any embedding weights besides the newly added token
503
+ with torch.no_grad():
504
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
505
+ index_no_updates
506
+ ]
507
+
508
+ # Checks if the accelerator has performed an optimization step behind the scenes
509
+ if accelerator.sync_gradients:
510
+ progress_bar.update(1)
511
+ global_step += 1
512
+ # TODO: fix sample_images
513
+ # train_util.sample_images(
514
+ # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
515
+ # )
516
+
517
+ # save the model every specified number of steps
518
+ if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
519
+ accelerator.wait_for_everyone()
520
+ if accelerator.is_main_process:
521
+ updated_embs = (
522
+ accelerator.unwrap_model(text_encoder)
523
+ .get_input_embeddings()
524
+ .weight[token_ids_XTI]
525
+ .data.detach()
526
+ .clone()
527
+ )
528
+
529
+ ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
530
+ save_model(ckpt_name, updated_embs, global_step, epoch)
531
+
532
+ if args.save_state:
533
+ train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
534
+
535
+ remove_step_no = train_util.get_remove_step_no(args, global_step)
536
+ if remove_step_no is not None:
537
+ remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
538
+ remove_model(remove_ckpt_name)
539
+
540
+ current_loss = loss.detach().item()
541
+ if args.logging_dir is not None:
542
+ logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
543
+ if (
544
+ args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
545
+ ): # tracking d*lr value
546
+ logs["lr/d*lr"] = (
547
+ lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
548
+ )
549
+ accelerator.log(logs, step=global_step)
550
+
551
+ loss_total += current_loss
552
+ avr_loss = loss_total / (step + 1)
553
+ logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
554
+ progress_bar.set_postfix(**logs)
555
+
556
+ if global_step >= args.max_train_steps:
557
+ break
558
+
559
+ if args.logging_dir is not None:
560
+ logs = {"loss/epoch": loss_total / len(train_dataloader)}
561
+ accelerator.log(logs, step=epoch + 1)
562
+
563
+ accelerator.wait_for_everyone()
564
+
565
+ updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
566
+
567
+ if args.save_every_n_epochs is not None:
568
+ saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
569
+ if accelerator.is_main_process and saving:
570
+ ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
571
+ save_model(ckpt_name, updated_embs, global_step, epoch + 1)
572
+
573
+ remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
574
+ if remove_epoch_no is not None:
575
+ remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
576
+ remove_model(remove_ckpt_name)
577
+
578
+ if args.save_state:
579
+ train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
580
+
581
+ # TODO: fix sample_images
582
+ # train_util.sample_images(
583
+ # accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
584
+ # )
585
+
586
+ # end of epoch
587
+
588
+ is_main_process = accelerator.is_main_process
589
+ if is_main_process:
590
+ text_encoder = accelerator.unwrap_model(text_encoder)
591
+
592
+ accelerator.end_training()
593
+
594
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
595
+ train_util.save_state_on_train_end(args, accelerator)
596
+
597
+ updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
598
+
599
+ del accelerator # delete to free memory, since it is needed after this point
600
+
601
+ if is_main_process:
602
+ ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
603
+ save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
604
+
605
+ logger.info("model saved.")
606
+
607
+
608
+ def save_weights(file, updated_embs, save_dtype):
609
+ updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
610
+ updated_embs = updated_embs.chunk(16)
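+ # the embedding rows are ordered layer-major (all vectors for IN01, then IN02, ...), so reshape + chunk yields one (num_vectors, dim) tensor per XTI layer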
611
+ XTI_layers = [
612
+ "IN01",
613
+ "IN02",
614
+ "IN04",
615
+ "IN05",
616
+ "IN07",
617
+ "IN08",
618
+ "MID",
619
+ "OUT03",
620
+ "OUT04",
621
+ "OUT05",
622
+ "OUT06",
623
+ "OUT07",
624
+ "OUT08",
625
+ "OUT09",
626
+ "OUT10",
627
+ "OUT11",
628
+ ]
629
+ state_dict = {}
630
+ for i, layer_name in enumerate(XTI_layers):
631
+ state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
632
+
633
+ # if save_dtype is not None:
634
+ # for key in list(state_dict.keys()):
635
+ # v = state_dict[key]
636
+ # v = v.detach().clone().to("cpu").to(save_dtype)
637
+ # state_dict[key] = v
638
+
639
+ if os.path.splitext(file)[1] == ".safetensors":
640
+ from safetensors.torch import save_file
641
+
642
+ save_file(state_dict, file)
643
+ else:
644
+ torch.save(state_dict, file) # can be loaded in Web UI
645
+
646
+
647
+ def load_weights(file):
648
+ if os.path.splitext(file)[1] == ".safetensors":
649
+ from safetensors.torch import load_file
650
+
651
+ data = load_file(file)
652
+ else:
653
+ raise ValueError(f"NOT XTI: {file}")
654
+
655
+ if len(data.values()) != 16:
656
+ raise ValueError(f"NOT XTI: {file}")
657
+
658
+ emb = torch.concat([x for x in data.values()])
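+ # concatenate the 16 per-layer embeddings back into a single tensor used to initialize the new token embeddings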
659
+
660
+ return emb
661
+
662
+
663
+ def setup_parser() -> argparse.ArgumentParser:
664
+ parser = argparse.ArgumentParser()
665
+
666
+ add_logging_arguments(parser)
667
+ train_util.add_sd_models_arguments(parser)
668
+ train_util.add_dataset_arguments(parser, True, True, False)
669
+ train_util.add_training_arguments(parser, True)
670
+ train_util.add_masked_loss_arguments(parser)
671
+ deepspeed_utils.add_deepspeed_arguments(parser)
672
+ train_util.add_optimizer_arguments(parser)
673
+ config_util.add_config_arguments(parser)
674
+ custom_train_functions.add_custom_train_arguments(parser, False)
675
+
676
+ parser.add_argument(
677
+ "--save_model_as",
678
+ type=str,
679
+ default="pt",
680
+ choices=[None, "ckpt", "pt", "safetensors"],
681
+ help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
682
+ )
683
+
684
+ parser.add_argument(
685
+ "--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
686
+ )
687
+ parser.add_argument(
688
+ "--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
689
+ )
690
+ parser.add_argument(
691
+ "--token_string",
692
+ type=str,
693
+ default=None,
694
+ help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
695
+ )
696
+ parser.add_argument(
697
+ "--init_word", type=str, default=None, help="words to initialize vector / ベ���トルを初期化に使用する単語、複数可"
698
+ )
699
+ parser.add_argument(
700
+ "--use_object_template",
701
+ action="store_true",
702
+ help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
703
+ )
704
+ parser.add_argument(
705
+ "--use_style_template",
706
+ action="store_true",
707
+ help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
708
+ )
709
+
710
+ return parser
711
+
712
+
713
+ if __name__ == "__main__":
714
+ parser = setup_parser()
715
+
716
+ args = parser.parse_args()
717
+ train_util.verify_command_line_training_args(args)
718
+ args = train_util.read_config_from_file(args, parser)
719
+
720
+ train(args)