Commit 3470339 ("aduc-sdr") · Parent(s): 7bdd354 · committed by Carlexxx

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- NOTICE.md +76 -0
- README.md +204 -7
- configs/ltxv-13b-0.9.7-dev.yaml +34 -0
- configs/ltxv-13b-0.9.7-distilled.yaml +28 -0
- configs/ltxv-13b-0.9.8-dev-fp8.yaml +34 -0
- configs/ltxv-13b-0.9.8-dev.yaml +34 -0
- configs/ltxv-13b-0.9.8-distilled-fp8.yaml +29 -0
- configs/ltxv-13b-0.9.8-distilled.yaml +29 -0
- configs/ltxv-2b-0.9.1.yaml +17 -0
- configs/ltxv-2b-0.9.5.yaml +17 -0
- configs/ltxv-2b-0.9.6-dev.yaml +17 -0
- configs/ltxv-2b-0.9.6-distilled.yaml +16 -0
- configs/ltxv-2b-0.9.8-distilled-fp8.yaml +28 -0
- configs/ltxv-2b-0.9.8-distilled.yaml +28 -0
- configs/ltxv-2b-0.9.yaml +17 -0
- deformes4D_engine.py +292 -0
- dreamo/LICENSE.txt +201 -0
- dreamo/README.md +135 -0
- dreamo/dreamo_pipeline.py +507 -0
- dreamo/transformer.py +187 -0
- dreamo/utils.py +232 -0
- flux_kontext_helpers.py +151 -0
- gemini_helpers.py +257 -0
- hardware_manager.py +35 -0
- i18n.json +128 -0
- image_specialist.py +98 -0
- inference.py +774 -0
- ltx_manager_helpers.py +198 -0
- ltx_video/LICENSE.txt +201 -0
- ltx_video/README.md +135 -0
- ltx_video/__init__.py +0 -0
- ltx_video/models/__init__.py +0 -0
- ltx_video/models/autoencoders/__init__.py +0 -0
- ltx_video/models/autoencoders/causal_conv3d.py +63 -0
- ltx_video/models/autoencoders/causal_video_autoencoder.py +1398 -0
- ltx_video/models/autoencoders/conv_nd_factory.py +90 -0
- ltx_video/models/autoencoders/dual_conv3d.py +217 -0
- ltx_video/models/autoencoders/latent_upsampler.py +203 -0
- ltx_video/models/autoencoders/pixel_norm.py +12 -0
- ltx_video/models/autoencoders/pixel_shuffle.py +33 -0
- ltx_video/models/autoencoders/vae.py +380 -0
- ltx_video/models/autoencoders/vae_encode.py +247 -0
- ltx_video/models/autoencoders/video_autoencoder.py +1045 -0
- ltx_video/models/transformers/__init__.py +0 -0
- ltx_video/models/transformers/attention.py +1264 -0
- ltx_video/models/transformers/embeddings.py +129 -0
- ltx_video/models/transformers/symmetric_patchifier.py +84 -0
- ltx_video/models/transformers/transformer3d.py +507 -0
- ltx_video/pipelines/__init__.py +0 -0
- ltx_video/pipelines/crf_compressor.py +50 -0
NOTICE.md
ADDED
@@ -0,0 +1,76 @@

# NOTICE

Copyright (C) 2025 Carlos Rodrigues dos Santos. All rights reserved.

---

## Aviso de Propriedade Intelectual e Licenciamento

### **Processo de Patenteamento em Andamento (EM PORTUGUÊS):**

O método e o sistema de orquestração de prompts denominados **ADUC (Automated Discovery and Orchestration of Complex tasks)**, conforme descritos neste documento e implementados neste software, estão atualmente em processo de patenteamento.

O titular dos direitos, Carlos Rodrigues dos Santos, está buscando proteção legal para as inovações chave da arquitetura ADUC, incluindo, mas não se limitando a:

* Fragmentação e escalonamento de solicitações que excedem limites de contexto de modelos de IA.
* Distribuição inteligente de sub-tarefas para especialistas heterogêneos.
* Gerenciamento de estado persistido com avaliação iterativa e realimentação para o planejamento de próximas etapas.
* Planejamento e roteamento sensível a custo, latência e requisitos de qualidade.
* O uso de "tokens universais" para comunicação agnóstica a modelos.

### **Reconhecimento e Implicações (EM PORTUGUÊS):**

Ao acessar ou utilizar este software e a arquitetura ADUC aqui implementada, você reconhece:

1. A natureza inovadora e a importância da arquitetura ADUC no campo da orquestração de prompts para IA.
2. Que a essência desta arquitetura, ou suas implementações derivadas, podem estar sujeitas a direitos de propriedade intelectual, incluindo patentes.
3. Que o uso comercial, a reprodução da lógica central da ADUC em sistemas independentes, ou a exploração direta da invenção sem o devido licenciamento podem infringir os direitos de patente pendente.

---

### **Patent Pending (IN ENGLISH):**

The method and system for prompt orchestration named **ADUC (Automated Discovery and Orchestration of Complex tasks)**, as described herein and implemented in this software, are currently in the process of being patented.

The rights holder, Carlos Rodrigues dos Santos, is seeking legal protection for the key innovations of the ADUC architecture, including, but not limited to:

* Fragmentation and scaling of requests exceeding AI model context limits.
* Intelligent distribution of sub-tasks to heterogeneous specialists.
* Persistent state management with iterative evaluation and feedback for planning subsequent steps.
* Cost, latency, and quality-aware planning and routing.
* The use of "universal tokens" for model-agnostic communication.

### **Acknowledgement and Implications (IN ENGLISH):**

By accessing or using this software and the ADUC architecture implemented herein, you acknowledge:

1. The innovative nature and significance of the ADUC architecture in the field of AI prompt orchestration.
2. That the essence of this architecture, or its derivative implementations, may be subject to intellectual property rights, including patents.
3. That commercial use, reproduction of ADUC's core logic in independent systems, or direct exploitation of the invention without proper licensing may infringe upon pending patent rights.

---

## Licença AGPLv3

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.

---

**Contato para Consultas:**

Para mais informações sobre a arquitetura ADUC, o status do patenteamento, ou para discutir licenciamento para usos comerciais ou não conformes com a AGPLv3, por favor, entre em contato:

Carlos Rodrigues dos Santos
Rua Eduardo Carlos Pereira, 4125, B1 Ap32, Curitiba, PR, Brazil, CEP 8102025
README.md
CHANGED
@@ -1,13 +1,210 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Euia-AducSdr
+emoji: 🎥
+colorFrom: indigo
+colorTo: purple
 sdk: gradio
-sdk_version: 5.44.0
 app_file: app.py
-pinned:
+pinned: true
 license: agpl-3.0
+short_description: Uma implementação aberta e funcional da arquitetura ADUC-SDR
 ---

### 🇧🇷 Português

Uma implementação aberta e funcional da arquitetura ADUC-SDR (Arquitetura de Unificação Compositiva - Escala Dinâmica e Resiliente), projetada para a geração de vídeo coerente de longa duração. Este projeto materializa os princípios de fragmentação, navegação geométrica e um mecanismo de "eco causal 4bits memoria" para garantir a continuidade física e narrativa em sequências de vídeo geradas por múltiplos modelos de IA.

**Licença:** Este projeto é licenciado sob os termos da **GNU Affero General Public License v3.0**. Isto significa que se você usar este software (ou qualquer trabalho derivado) para fornecer um serviço através de uma rede, você é **obrigado a disponibilizar o código-fonte completo** da sua versão para os usuários desse serviço.

- **Copyright (C) 4 de Agosto de 2025, Carlos Rodrigues dos Santos**
- Uma cópia completa da licença pode ser encontrada no arquivo [LICENSE](LICENSE).

---

### 🇬🇧 English

An open and functional implementation of the ADUC-SDR (Architecture for Compositive Unification - Dynamic and Resilient Scaling) architecture, designed for long-form coherent video generation. This project materializes the principles of fragmentation, geometric navigation, and a "causal echo 4-bit memory" mechanism to ensure physical and narrative continuity in video sequences generated by multiple AI models.

**License:** This project is licensed under the terms of the **GNU Affero General Public License v3.0**. This means that if you use this software (or any derivative work) to provide a service over a network, you are **required to make the complete source code** of your version available to the users of that service.

- **Copyright (C) August 4, 2025, Carlos Rodrigues dos Santos**
- A full copy of the license can be found in the [LICENSE](LICENSE) file.

---

## **Aviso de Propriedade Intelectual e Patenteamento**

### **Processo de Patenteamento em Andamento (EM PORTUGUÊS):**

A arquitetura e o método **ADUC (Automated Discovery and Orchestration of Complex tasks)**, conforme descritos neste projeto e nas reivindicações associadas, estão **atualmente em processo de patenteamento**.

O titular dos direitos, Carlos Rodrigues dos Santos, está buscando proteção legal para as inovações chave da arquitetura ADUC, que incluem, mas não se limitam a:

* Fragmentação e escalonamento de solicitações que excedem limites de contexto de modelos de IA.
* Distribuição inteligente de sub-tarefas para especialistas heterogêneos.
* Gerenciamento de estado persistido com avaliação iterativa e realimentação para o planejamento de próximas etapas.
* Planejamento e roteamento sensível a custo, latência e requisitos de qualidade.
* O uso de "tokens universais" para comunicação agnóstica a modelos.

Ao utilizar este software e a arquitetura ADUC aqui implementada, você reconhece a natureza inovadora desta arquitetura e que a **reprodução ou exploração da lógica central da ADUC em sistemas independentes pode infringir direitos de patente pendente.**

---

### **Patent Pending (IN ENGLISH):**

The **ADUC (Automated Discovery and Orchestration of Complex tasks)** architecture and method, as described in this project and its associated claims, are **currently in the process of being patented.**

The rights holder, Carlos Rodrigues dos Santos, is seeking legal protection for the key innovations of the ADUC architecture, including, but not limited to:

* Fragmentation and scaling of requests exceeding AI model context limits.
* Intelligent distribution of sub-tasks to heterogeneous specialists.
* Persistent state management with iterative evaluation and feedback for planning subsequent steps.
* Cost, latency, and quality-aware planning and routing.
* The use of "universal tokens" for model-agnostic communication.

By using this software and the ADUC architecture implemented herein, you acknowledge the innovative nature of this architecture and that **the reproduction or exploitation of ADUC's core logic in independent systems may infringe upon pending patent rights.**

---

### Detalhes Técnicos e Reivindicações da ADUC

#### 🇧🇷 Definição Curta (para Tese e Patente)

**ADUC** é um *framework pré-input* e *intermediário* de **gerenciamento de prompts** que:

1. **fragmenta** solicitações acima do limite de contexto de qualquer modelo,
2. **escala linearmente** (processo sequencial com memória persistida),
3. **distribui** sub-tarefas a **especialistas** (modelos/ferramentas heterogêneos), e
4. **realimenta** a próxima etapa com avaliação do que foi feito/esperado (LLM diretor).

Não é um modelo; é uma **camada orquestradora** plugável antes do input de modelos existentes (texto, imagem, áudio, vídeo), usando *tokens universais* e a tecnologia atual.

#### 🇬🇧 Short Definition (for Thesis and Patent)

**ADUC** is a *pre-input* and *intermediate* **prompt management framework** that:

1. **fragments** requests exceeding any model's context limit,
2. **scales linearly** (sequential process with persisted memory),
3. **distributes** sub-tasks to **specialists** (heterogeneous models/tools), and
4. **feeds back** to the next step with an evaluation of what was done/expected (director LLM).

It is not a model; it is a pluggable **orchestration layer** before the input of existing models (text, image, audio, video), using *universal tokens* and current technology.

---

#### 🇧🇷 Elementos Essenciais (Telegráfico)

* **Agnóstico a modelos:** opera com qualquer LLM/difusor/API.
* **Pré-input manager:** recebe pedido do usuário, **divide** em blocos ≤ limite de tokens, **prioriza**, **agenda** e **roteia**.
* **Memória persistida:** resultados/latentes/"eco" viram **estado compartilhado** para o próximo bloco (nada é ignorado).
* **Especialistas:** *routers* decidem quem faz o quê (ex.: "descrição → LLM-A", "keyframe → Img-B", "vídeo → Vid-C").
* **Controle de qualidade:** LLM diretor compara *o que fez* × *o que deveria* × *o que falta* e **regenera objetivos** do próximo fragmento.
* **Custo/latência-aware:** planeja pela **VRAM/tempo/custo**, não tenta "abraçar tudo de uma vez".

#### 🇬🇧 Essential Elements (Telegraphic)

* **Model-agnostic:** operates with any LLM/diffuser/API.
* **Pre-input manager:** receives user request, **divides** into blocks ≤ token limit, **prioritizes**, **schedules**, and **routes**.
* **Persisted memory:** results/latents/"echo" become **shared state** for the next block (nothing is ignored).
* **Specialists:** *routers* decide who does what (e.g., "description → LLM-A", "keyframe → Img-B", "video → Vid-C").
* **Quality control:** director LLM compares *what was done* × *what should be done* × *what is missing* and **regenerates objectives** for the next fragment.
* **Cost/latency-aware:** plans by **VRAM/time/cost**, does not try to "embrace everything at once".

---

#### 🇧🇷 Reivindicações Independentes (Método e Sistema)

**Reivindicação Independente (Método) — Versão Enxuta:**

1. **Método** de **orquestração de prompts** para execução de tarefas acima do limite de contexto de modelos de IA, compreendendo:
(a) **receber** uma solicitação que excede um limite de tokens;
(b) **analisar** a solicitação por um **LLM diretor** e **fragmentá-la** em sub-tarefas ≤ limite;
(c) **selecionar** especialistas de execução para cada sub-tarefa com base em capacidades declaradas;
(d) **gerar** prompts específicos por sub-tarefa em **tokens universais**, incluindo referências ao **estado persistido** de execuções anteriores;
(e) **executar sequencialmente** as sub-tarefas e **persistir** suas saídas como memória (incluindo latentes/eco/artefatos);
(f) **avaliar** automaticamente a saída versus metas declaradas e **regenerar objetivos** do próximo fragmento;
(g) **iterar** (b)–(f) até que os critérios de completude sejam atendidos, produzindo o resultado agregado;
em que o framework **escala linearmente** no tempo e armazenamento físico, **independente** da janela de contexto dos modelos subjacentes.

**Reivindicação Independente (Sistema):**

2. **Sistema** de orquestração de prompts, compreendendo: um **planejador LLM diretor**; um **roteador de especialistas**; um **banco de estado persistido** (incl. memória cinética para vídeo); um **gerador de prompts universais**; e um **módulo de avaliação/realimentação**, acoplados por uma **API pré-input** a modelos heterogêneos.

#### 🇬🇧 Independent Claims (Method and System)

**Independent Claim (Method) — Concise Version:**

1. A **method** for **prompt orchestration** for executing tasks exceeding AI model context limits, comprising:
(a) **receiving** a request that exceeds a token limit;
(b) **analyzing** the request by a **director LLM** and **fragmenting it** into sub-tasks ≤ the limit;
(c) **selecting** execution specialists for each sub-task based on declared capabilities;
(d) **generating** specific prompts per sub-task in **universal tokens**, including references to the **persisted state** of previous executions;
(e) **sequentially executing** the sub-tasks and **persisting** their outputs as memory (including latents/echo/artifacts);
(f) **automatically evaluating** the output against declared goals and **regenerating objectives** for the next fragment;
(g) **iterating** (b)–(f) until completion criteria are met, producing the aggregated result;
wherein the framework **scales linearly** in time and physical storage, **independent** of the context window of the underlying models.

**Independent Claim (System):**

2. A prompt orchestration **system**, comprising: a **director LLM planner**; a **specialist router**; a **persisted state bank** (incl. kinetic memory for video); a **universal prompt generator**; and an **evaluation/feedback module**, coupled via a **pre-input API** to heterogeneous models.

---

#### 🇧🇷 Dependentes Úteis

* (3) Onde o roteamento considera **custo/latência/VRAM** e metas de qualidade.
* (4) Onde o banco de estado inclui **eco cinético** para vídeo (últimos *n* frames/latentes/fluxo).
* (5) Onde a avaliação usa métricas específicas por domínio (Lflow, consistência semântica, etc.).
* (6) Onde *tokens universais* padronizam instruções entre especialistas.
* (7) Onde a orquestração decide **cut vs continuous** e **corte regenerativo** (Déjà-Vu) ao editar vídeo.
* (8) Onde o sistema **nunca descarta** conteúdo excedente: **reagenda** em novos fragmentos.

#### 🇬🇧 Useful Dependents

* (3) Wherein routing considers **cost/latency/VRAM** and quality goals.
* (4) Wherein the state bank includes **kinetic echo** for video (last *n* frames/latents/flow).
* (5) Wherein evaluation uses domain-specific metrics (Lflow, semantic consistency, etc.).
* (6) Wherein *universal tokens* standardize instructions between specialists.
* (7) Wherein orchestration decides **cut vs continuous** and **regenerative cut** (Déjà-Vu) when editing video.
* (8) Wherein the system **never discards** excess content: it **reschedules** it in new fragments.

---

#### 🇧🇷 Como isso conversa com SDR (Vídeo)

* **Eco Cinético**: é um **tipo de estado persistido** consumido pelo próximo passo.
* **Déjà-Vu (Corte Regenerativo)**: é **uma política de orquestração** aplicada quando há edição; ADUC decide, monta os prompts certos e chama o especialista de vídeo.
* **Cut vs Continuous**: decisão do **diretor** com base em estado + metas; ADUC roteia e garante a sobreposição/remoção final.

#### 🇬🇧 How this Converses with SDR (Video)

* **Kinetic Echo**: is a **type of persisted state** consumed by the next step.
* **Déjà-Vu (Regenerative Cut)**: is an **orchestration policy** applied during editing; ADUC decides, crafts the right prompts, and calls the video specialist.
* **Cut vs Continuous**: decision made by the **director** based on state + goals; ADUC routes and ensures the final overlap/removal.

---

#### 🇧🇷 Mensagem Clara ao Usuário (Experiência)

> "Seu pedido excede o limite X do modelo Y. Em vez de truncar silenciosamente, o **ADUC** dividirá e **entregará 100%** do conteúdo por etapas coordenadas."

Isso é diferencial prático e jurídico: **não-obviedade** por transformar limite de contexto em **pipeline controlado**, com **persistência de estado** e **avaliação iterativa**.

#### 🇬🇧 Clear User Message (Experience)

> "Your request exceeds model Y's limit X. Instead of silently truncating, **ADUC** will divide and **deliver 100%** of the content through coordinated steps."

This is a practical and legal differentiator: **non-obviousness** by transforming context limits into a **controlled pipeline**, with **state persistence** and **iterative evaluation**.

---

### Contact / Contato / Contacto

- **Author / Autor:** Carlos Rodrigues dos Santos
- **Email:** [email protected]
- **GitHub:** [https://github.com/carlex22/Aduc-sdr](https://github.com/carlex22/Aduc-sdr)
- **Hugging Face Spaces:**
  - [Ltx-SuperTime-60Secondos](https://huggingface.co/spaces/Carlexx/Ltx-SuperTime-60Secondos/)
  - [Novinho](https://huggingface.co/spaces/Carlexxx/Novinho/)

---
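The README above describes the ADUC loop only at the architecture level. As a purely illustrative aid that is not part of this commit, the sketch below shows one toy way the fragment → route → execute → persist → evaluate cycle from the method claim could be wired up in Python; every function here (`fragment`, `route`, `execute`, `evaluate`) is a trivial stand-in for the director LLM, specialist router, specialists, and evaluation module of the real system.

```python
# Toy ADUC-style orchestration loop (illustrative sketch only, not from this repository).
from dataclasses import dataclass, field

@dataclass
class State:
    memory: list = field(default_factory=list)  # persisted outputs ("echo", latents, artifacts)

def fragment(request: str, token_limit: int) -> list[str]:
    # (b) director step: split an oversized request into blocks <= token_limit words
    words = request.split()
    return [" ".join(words[i:i + token_limit]) for i in range(0, len(words), token_limit)]

def route(sub_task: str) -> str:
    # (c) pick a specialist from declared capabilities (toy rule)
    return "video-specialist" if "scene" in sub_task else "text-specialist"

def execute(specialist: str, sub_task: str, memory: list) -> str:
    # (d)+(e) build a prompt that references persisted state, then run the specialist
    context = memory[-1] if memory else "no prior state"
    return f"[{specialist}] handled '{sub_task}' given '{context}'"

def evaluate(output: str, remaining: list[str]) -> list[str]:
    # (f) compare done vs. expected; this toy version leaves the remaining plan unchanged
    return remaining

def orchestrate(request: str, token_limit: int = 8) -> list[str]:
    state = State()
    pending = fragment(request, token_limit)
    while pending:                   # (g) iterate until completion criteria are met
        task = pending.pop(0)
        out = execute(route(task), task, state.memory)
        state.memory.append(out)     # nothing is discarded; every output is persisted
        pending = evaluate(out, pending)
    return state.memory

print(orchestrate("describe the opening scene then generate a keyframe and a short video clip"))
```

Even in this toy form, the key property of the claim is visible: the request is never truncated, each block fits the declared limit, and each step sees the persisted state of the previous one.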
configs/ltxv-13b-0.9.7-dev.yaml
ADDED
@@ -0,0 +1,34 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true
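These configs are plain YAML consumed by the LTX-Video pipeline code added in this commit. A minimal sketch of inspecting one of them with PyYAML is shown below; it assumes PyYAML is installed and says nothing about the actual entry point (e.g., inference.py) that consumes these files, which is not shown in this truncated view.

```python
# Minimal sketch: load one of the pipeline configs and read a few fields (assumes PyYAML).
import yaml

with open("configs/ltxv-13b-0.9.7-dev.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["pipeline_type"])                      # "multi-scale"
print(cfg["first_pass"]["num_inference_steps"])  # 30
print(cfg["second_pass"]["guidance_timesteps"])  # [1.0]
```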
configs/ltxv-13b-0.9.7-distilled.yaml
ADDED
@@ -0,0 +1,28 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.7-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
configs/ltxv-13b-0.9.8-dev-fp8.yaml
ADDED
@@ -0,0 +1,34 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true
configs/ltxv-13b-0.9.8-dev.yaml
ADDED
@@ -0,0 +1,34 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-dev.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
  stg_scale: [0, 0, 4, 4, 4, 2, 1]
  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
  num_inference_steps: 30
  skip_final_inference_steps: 3
  cfg_star_rescale: true

second_pass:
  guidance_scale: [1]
  stg_scale: [1]
  rescaling_scale: [1]
  guidance_timesteps: [1.0]
  skip_block_list: [27]
  num_inference_steps: 30
  skip_initial_inference_steps: 17
  cfg_star_rescale: true
configs/ltxv-13b-0.9.8-distilled-fp8.yaml
ADDED
@@ -0,0 +1,29 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6
configs/ltxv-13b-0.9.8-distilled.yaml
ADDED
@@ -0,0 +1,29 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-13b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
  tone_map_compression_ratio: 0.6
configs/ltxv-2b-0.9.1.yaml
ADDED
@@ -0,0 +1,17 @@

pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
configs/ltxv-2b-0.9.5.yaml
ADDED
@@ -0,0 +1,17 @@

pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
configs/ltxv-2b-0.9.6-dev.yaml
ADDED
@@ -0,0 +1,17 @@

pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
configs/ltxv-2b-0.9.6-distilled.yaml
ADDED
@@ -0,0 +1,16 @@

pipeline_type: base
checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
guidance_scale: 1
stg_scale: 0
rescaling_scale: 1
num_inference_steps: 8
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: true
configs/ltxv-2b-0.9.8-distilled-fp8.yaml
ADDED
@@ -0,0 +1,28 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled-fp8.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "float8_e4m3fn" # options: "float8_e4m3fn", "bfloat16", "mixed_precision"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
configs/ltxv-2b-0.9.8-distilled.yaml
ADDED
@@ -0,0 +1,28 @@

pipeline_type: multi-scale
checkpoint_path: "ltxv-2b-0.9.8-distilled.safetensors"
downscale_factor: 0.6666666
spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.8.safetensors"
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false

first_pass:
  timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]

second_pass:
  timesteps: [0.9094, 0.7250, 0.4219]
  guidance_scale: 1
  stg_scale: 0
  rescaling_scale: 1
  skip_block_list: [42]
configs/ltxv-2b-0.9.yaml
ADDED
@@ -0,0 +1,17 @@

pipeline_type: base
checkpoint_path: "ltx-video-2b-v0.9.safetensors"
guidance_scale: 3
stg_scale: 1
rescaling_scale: 0.7
skip_block_list: [19]
num_inference_steps: 40
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
decode_timestep: 0.05
decode_noise_scale: 0.025
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
precision: "bfloat16"
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
prompt_enhancement_words_threshold: 120
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
stochastic_sampling: false
deformes4D_engine.py
ADDED
@@ -0,0 +1,292 @@

# deformes4D_engine.py
# Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
#
# MODIFICATIONS FOR ADUC-SDR:
# Copyright (C) 2025 Carlos Rodrigues dos Santos. All rights reserved.
#
# This file is part of the ADUC-SDR project. It contains the core logic for
# video fragment generation, latent manipulation, and dynamic editing,
# governed by the ADUC orchestrator.
# This component is licensed under the GNU Affero General Public License v3.0.
#
# AVISO DE PATENTE PENDENTE: O método e sistema ADUC implementado neste
# software está em processo de patenteamento. Consulte NOTICE.md.

import os
import time
import imageio
import numpy as np
import torch
import logging
from PIL import Image, ImageOps
from dataclasses import dataclass
import gradio as gr
import subprocess
import random
import gc

from audio_specialist import audio_specialist_singleton
from ltx_manager_helpers import ltx_manager_singleton
from flux_kontext_helpers import flux_kontext_singleton
from gemini_helpers import gemini_singleton
from ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode

logger = logging.getLogger(__name__)


@dataclass
class LatentConditioningItem:
    latent_tensor: torch.Tensor
    media_frame_number: int
    conditioning_strength: float


class Deformes4DEngine:
    def __init__(self, ltx_manager, workspace_dir="deformes_workspace"):
        self.ltx_manager = ltx_manager
        self.workspace_dir = workspace_dir
        self._vae = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info("Especialista Deformes4D (SDR Executor) inicializado.")

    @property
    def vae(self):
        # Lazily borrow the VAE from the first LTX worker pipeline.
        if self._vae is None:
            self._vae = self.ltx_manager.workers[0].pipeline.vae
            self._vae.to(self.device)
            self._vae.eval()
        return self._vae

    def save_latent_tensor(self, tensor: torch.Tensor, path: str):
        torch.save(tensor.cpu(), path)
        logger.info(f"Tensor latente salvo em: {path}")

    def load_latent_tensor(self, path: str) -> torch.Tensor:
        tensor = torch.load(path, map_location=self.device)
        logger.info(f"Tensor latente carregado de: {path} para o dispositivo {self.device}")
        return tensor

    @torch.no_grad()
    def pixels_to_latents(self, tensor: torch.Tensor) -> torch.Tensor:
        tensor = tensor.to(self.device, dtype=self.vae.dtype)
        return vae_encode(tensor, self.vae, vae_per_channel_normalize=True)

    @torch.no_grad()
    def latents_to_pixels(self, latent_tensor: torch.Tensor, decode_timestep: float = 0.05) -> torch.Tensor:
        latent_tensor = latent_tensor.to(self.device, dtype=self.vae.dtype)
        timestep_tensor = torch.tensor([decode_timestep] * latent_tensor.shape[0], device=self.device, dtype=latent_tensor.dtype)
        return vae_decode(latent_tensor, self.vae, is_video=True, timestep=timestep_tensor, vae_per_channel_normalize=True)

    def save_video_from_tensor(self, video_tensor: torch.Tensor, path: str, fps: int = 24):
        if video_tensor is None or video_tensor.ndim != 5 or video_tensor.shape[2] == 0:
            return
        video_tensor = video_tensor.squeeze(0).permute(1, 2, 3, 0)
        video_tensor = (video_tensor.clamp(-1, 1) + 1) / 2.0
        video_np = (video_tensor.detach().cpu().float().numpy() * 255).astype(np.uint8)
        with imageio.get_writer(path, fps=fps, codec='libx264', quality=8) as writer:
            for frame in video_np:
                writer.append_data(frame)
        logger.info(f"Vídeo salvo em: {path}")

    def _preprocess_image_for_latent_conversion(self, image: Image.Image, target_resolution: tuple) -> Image.Image:
        if image.size != target_resolution:
            logger.info(f" - AÇÃO: Redimensionando imagem de {image.size} para {target_resolution} antes da conversão para latente.")
            return ImageOps.fit(image, target_resolution, Image.Resampling.LANCZOS)
        return image

    def pil_to_latent(self, pil_image: Image.Image) -> torch.Tensor:
        image_np = np.array(pil_image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0).unsqueeze(2)
        tensor = (tensor * 2.0) - 1.0
        return self.pixels_to_latents(tensor)

    def _generate_video_and_audio_from_latents(self, latent_tensor, audio_prompt, base_name):
        silent_video_path = os.path.join(self.workspace_dir, f"{base_name}_silent.mp4")
        pixel_tensor = self.latents_to_pixels(latent_tensor)
        self.save_video_from_tensor(pixel_tensor, silent_video_path, fps=24)
        del pixel_tensor
        gc.collect()

        try:
            result = subprocess.run(
                ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", silent_video_path],
                capture_output=True, text=True, check=True)
            frag_duration = float(result.stdout.strip())
        except (subprocess.CalledProcessError, ValueError, FileNotFoundError):
            logger.warning(f"ffprobe falhou em {os.path.basename(silent_video_path)}. Calculando duração manualmente.")
            num_pixel_frames = latent_tensor.shape[2] * 8
            frag_duration = num_pixel_frames / 24.0

        video_with_audio_path = audio_specialist_singleton.generate_audio_for_video(
            video_path=silent_video_path, prompt=audio_prompt,
            duration_seconds=frag_duration)

        if os.path.exists(silent_video_path):
            os.remove(silent_video_path)
        return video_with_audio_path

    def _generate_latent_tensor_internal(self, conditioning_items, ltx_params, target_resolution, total_frames_to_generate):
        final_ltx_params = {
            **ltx_params,
            'width': target_resolution[0], 'height': target_resolution[1],
            'video_total_frames': total_frames_to_generate, 'video_fps': 24,
            'current_fragment_index': int(time.time()),
            'conditioning_items_data': conditioning_items
        }
        new_full_latents, _ = self.ltx_manager.generate_latent_fragment(**final_ltx_params)
        return new_full_latents

    def concatenate_videos_ffmpeg(self, video_paths: list[str], output_path: str) -> str:
        if not video_paths:
            raise gr.Error("Nenhum fragmento de vídeo para montar.")
        list_file_path = os.path.join(self.workspace_dir, "concat_list.txt")
        with open(list_file_path, 'w', encoding='utf-8') as f:
            for path in video_paths:
                f.write(f"file '{os.path.abspath(path)}'\n")
        cmd_list = ['ffmpeg', '-y', '-f', 'concat', '-safe', '0', '-i', list_file_path, '-c', 'copy', output_path]
        logger.info("Executando concatenação FFmpeg...")
        try:
            subprocess.run(cmd_list, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as e:
            logger.error(f"Erro no FFmpeg: {e.stderr}")
            raise gr.Error(f"Falha na montagem final do vídeo. Detalhes: {e.stderr}")
        return output_path

    def generate_full_movie(self,
                            keyframes: list,
                            global_prompt: str,
                            storyboard: list,
                            seconds_per_fragment: float,
                            overlap_percent: int,
                            echo_frames: int,
                            handler_strength: float,
                            destination_convergence_strength: float,
                            base_ltx_params: dict,
                            video_resolution: int,
                            use_continuity_director: bool,
                            progress: gr.Progress = gr.Progress()):

        keyframe_paths = [item[0] if isinstance(item, tuple) else item for item in keyframes]
        video_clips_paths, story_history, audio_history = [], "", "This is the beginning of the film."
        target_resolution_tuple = (video_resolution, video_resolution)
        n_trim_latents = 24  # self._quantize_to_multiple(int(seconds_per_fragment * 24 * (overlap_percent / 100.0)), 8)
        echo_frames = 8

        previous_latents_path = None
        num_transitions_to_generate = len(keyframe_paths) - 1

        for i in range(num_transitions_to_generate):
            progress((i + 1) / num_transitions_to_generate, desc=f"Produzindo Transição {i+1}/{num_transitions_to_generate}")

            start_keyframe_path = keyframe_paths[i]
            destination_keyframe_path = keyframe_paths[i+1]
            present_scene_desc = storyboard[i]

            is_first_fragment = previous_latents_path is None
            if is_first_fragment:
                transition_type = "start"
                motion_prompt = gemini_singleton.get_initial_motion_prompt(
                    global_prompt, start_keyframe_path, destination_keyframe_path, present_scene_desc
                )
            else:
                past_keyframe_path = keyframe_paths[i-1]
                past_scene_desc = storyboard[i-1]
                future_scene_desc = storyboard[i+1] if (i+1) < len(storyboard) else "A cena final."
                decision = gemini_singleton.get_cinematic_decision(
                    global_prompt=global_prompt, story_history=story_history,
                    past_keyframe_path=past_keyframe_path, present_keyframe_path=start_keyframe_path,
                    future_keyframe_path=destination_keyframe_path, past_scene_desc=past_scene_desc,
                    present_scene_desc=present_scene_desc, future_scene_desc=future_scene_desc
                )
                transition_type, motion_prompt = decision["transition_type"], decision["motion_prompt"]

            story_history += f"\n- Ato {i+1} ({transition_type}): {motion_prompt}"

            if use_continuity_director:  # Assume-se que este checkbox controla os diretores de vídeo e som
                if is_first_fragment:
                    audio_prompt = gemini_singleton.get_sound_director_prompt(
                        audio_history=audio_history,
                        past_keyframe_path=start_keyframe_path, present_keyframe_path=start_keyframe_path,
                        future_keyframe_path=destination_keyframe_path, present_scene_desc=present_scene_desc,
                        motion_prompt=motion_prompt, future_scene_desc=storyboard[i+1] if (i+1) < len(storyboard) else "The final scene."
                    )
                else:
                    audio_prompt = gemini_singleton.get_sound_director_prompt(
                        audio_history=audio_history, past_keyframe_path=keyframe_paths[i-1],
                        present_keyframe_path=start_keyframe_path, future_keyframe_path=destination_keyframe_path,
                        present_scene_desc=present_scene_desc, motion_prompt=motion_prompt,
                        future_scene_desc=storyboard[i+1] if (i+1) < len(storyboard) else "The final scene."
                    )
            else:
                audio_prompt = present_scene_desc  # Fallback para o prompt da cena se o diretor de som estiver desligado

            audio_history = audio_prompt

            conditioning_items = []
            current_ltx_params = {**base_ltx_params, "handler_strength": handler_strength, "motion_prompt": motion_prompt}
            total_frames_to_generate = self._quantize_to_multiple(int(seconds_per_fragment * 24), 8) + 1

            if is_first_fragment:
                img_start = self._preprocess_image_for_latent_conversion(Image.open(start_keyframe_path).convert("RGB"), target_resolution_tuple)
                start_latent = self.pil_to_latent(img_start)
                conditioning_items.append(LatentConditioningItem(start_latent, 0, 1.0))
                if transition_type != "cut":
                    img_dest = self._preprocess_image_for_latent_conversion(Image.open(destination_keyframe_path).convert("RGB"), target_resolution_tuple)
                    destination_latent = self.pil_to_latent(img_dest)
                    conditioning_items.append(LatentConditioningItem(destination_latent, total_frames_to_generate - 1, destination_convergence_strength))
            else:
                previous_latents = self.load_latent_tensor(previous_latents_path)
                handler_latent = previous_latents[:, :, -1:, :, :]
                trimmed_for_echo = previous_latents[:, :, :-n_trim_latents, :, :] if n_trim_latents > 0 and previous_latents.shape[2] > n_trim_latents else previous_latents
                echo_latents = trimmed_for_echo[:, :, -echo_frames:, :, :]
                handler_frame_position = n_trim_latents + echo_frames

                conditioning_items = []

                # Condition on each of the last `echo_frames` latent frames (the "kinetic echo").
                # The original loop enumerated the integer `echo_frames` and shadowed the outer
                # loop variable `i`; iterating the latent frames with `j` fixes both issues.
                for j in range(echo_latents.shape[2]):
                    echo_latent = echo_latents[:, :, j:j + 1, :, :]
                    weight = 1.0 if j == 0 else random.uniform(0.2, 0.7)
                    conditioning_items.append(LatentConditioningItem(echo_latent, 0, weight))
                # conditioning_items.append(LatentConditioningItem(echo_latents, 0, 1.0))
                conditioning_items.append(LatentConditioningItem(handler_latent, handler_frame_position, handler_strength))
                del previous_latents, handler_latent, trimmed_for_echo, echo_latents
                gc.collect()
                if transition_type == "continuous":
                    img_dest = self._preprocess_image_for_latent_conversion(Image.open(destination_keyframe_path).convert("RGB"), target_resolution_tuple)
                    destination_latent = self.pil_to_latent(img_dest)
                    conditioning_items.append(LatentConditioningItem(destination_latent, total_frames_to_generate - 1, destination_convergence_strength))

            new_full_latents = self._generate_latent_tensor_internal(conditioning_items, current_ltx_params, target_resolution_tuple, total_frames_to_generate)

            base_name = f"fragment_{i}_{int(time.time())}"
            new_full_latents_path = os.path.join(self.workspace_dir, f"{base_name}_full.pt")
            self.save_latent_tensor(new_full_latents, new_full_latents_path)

            previous_latents_path = new_full_latents_path

            latents_for_video = new_full_latents
+
|
| 267 |
+
if not is_first_fragment:
|
| 268 |
+
if echo_frames > 0 and latents_for_video.shape[2] > echo_frames: latents_for_video = latents_for_video[:, :, echo_frames:, :, :]
|
| 269 |
+
if n_trim_latents > 0 and latents_for_video.shape[2] > n_trim_latents: latents_for_video = latents_for_video[:, :, :-n_trim_latents, :, :]
|
| 270 |
+
else:
|
| 271 |
+
if n_trim_latents > 0 and latents_for_video.shape[2] > n_trim_latents: latents_for_video = latents_for_video[:, :, :-n_trim_latents, :, :]
|
| 272 |
+
|
| 273 |
+
video_with_audio_path = self._generate_video_and_audio_from_latents(latents_for_video, audio_prompt, base_name)
|
| 274 |
+
video_clips_paths.append(video_with_audio_path)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
if transition_type == "cut":
|
| 278 |
+
previous_latents_path = None
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
yield {"fragment_path": video_with_audio_path}
|
| 282 |
+
|
| 283 |
+
final_movie_path = os.path.join(self.workspace_dir, f"final_movie_{int(time.time())}.mp4")
|
| 284 |
+
self.concatenate_videos_ffmpeg(video_clips_paths, final_movie_path)
|
| 285 |
+
|
| 286 |
+
logger.info(f"Filme completo salvo em: {final_movie_path}")
|
| 287 |
+
yield {"final_path": final_movie_path}
|
| 288 |
+
|
| 289 |
+
def _quantize_to_multiple(self, n, m):
|
| 290 |
+
if m == 0: return n
|
| 291 |
+
quantized = int(round(n / m) * m)
|
| 292 |
+
return m if n > 0 and quantized == 0 else quantized
|
dreamo/LICENSE.txt
ADDED
|
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
dreamo/README.md
ADDED
|
@@ -0,0 +1,135 @@
# 🛠️ helpers/ - Third-Party AI Tools Adapted for ADUC-SDR

This folder contains adapted implementations of third-party AI models and utilities, which serve as low-level "specialists" or "tools" for the ADUC-SDR architecture.

**IMPORTANT:** The content of this folder was authored by its respective original creators and developers. This folder is **NOT PART** of the main ADUC-SDR project in terms of its novel architecture. It serves as a repository for the **direct, modified dependencies** that the `DeformesXDEngines` (the stages of the ADUC-SDR "rocket") invoke to perform specific tasks (image, video, and audio generation).

The modifications made to the files in this folder mainly aim at:
1. **Interface Adaptation:** Standardizing the interfaces so that they fit into the ADUC-SDR orchestration flow.
2. **Resource Management:** Integrating model loading/unloading logic (GPU management) and configuration via YAML files.
3. **Pipeline Optimization:** Adjusting the pipelines to accept more efficient input formats (e.g., pre-encoded tensors instead of media paths, skipping redundant encode/decode steps).

---

## 📄 Licensing

The original content of the projects listed below is licensed under the **Apache License 2.0**, or another license specified by the original authors. All modifications and the use of these files within the `helpers/` structure of the ADUC-SDR project comply with the terms of the **Apache License 2.0**.

The projects' original licenses can be found at their respective sources or in the `incl_licenses/` subdirectories inside each adapted module.

---

## 🛠️ Helper APIs and Usage Guide

This section details how each helper (specialist agent) should be used within the ADUC-SDR ecosystem. All agents are instantiated as **singletons** in `hardware_manager.py` to guarantee centralized management of GPU resources.

### **gemini_helpers.py (GeminiAgent)**

* **Purpose:** Acts as the "Adaptive Synthesis Oracle", responsible for all natural-language processing tasks, such as storyboard creation, prompt generation, and narrative decision-making.
* **Singleton Instance:** `gemini_agent_singleton`
* **Constructor:** `GeminiAgent()`
    * Reads `configs/gemini_config.yaml` for the model name, inference parameters, and prompt template paths. The API key is read from the `GEMINI_API_KEY` environment variable.
* **Public Methods:**
    * `generate_storyboard(prompt: str, num_keyframes: int, ref_image_paths: list[str])`
        * **Inputs:**
            * `prompt`: The overall idea of the film (string).
            * `num_keyframes`: The number of scenes to generate (int).
            * `ref_image_paths`: List of paths to the reference images (list[str]).
        * **Output:** `tuple[list[str], str]` (a tuple containing the storyboard list of strings and a textual report of the operation).
    * `select_keyframes_from_pool(storyboard: list, base_image_paths: list[str], pool_image_paths: list[str])`
        * **Inputs:**
            * `storyboard`: The list of strings from the generated storyboard.
            * `base_image_paths`: Base reference images (list[str]).
            * `pool_image_paths`: The "image pool" to select from (list[str]).
        * **Output:** `tuple[list[str], str]` (a tuple containing the list of selected image paths and a textual report).
    * `get_anticipatory_keyframe_prompt(...)`
        * **Inputs:** Narrative and visual context used to generate an image prompt.
        * **Output:** `tuple[str, str]` (a tuple containing the prompt generated for the image model and a textual report).
    * `get_initial_motion_prompt(...)`
        * **Inputs:** Narrative and visual context for the first video transition.
        * **Output:** `tuple[str, str]` (a tuple containing the generated motion prompt and a textual report).
    * `get_transition_decision(...)`
        * **Inputs:** Narrative and visual context for an intermediate video transition.
        * **Output:** `tuple[dict, str]` (a tuple containing a `{"transition_type": "...", "motion_prompt": "..."}` dictionary and a textual report).
    * `generate_audio_prompts(...)`
        * **Inputs:** Global narrative context.
        * **Output:** `tuple[dict, str]` (a tuple containing a `{"music_prompt": "...", "sfx_prompt": "..."}` dictionary and a textual report).

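For orientation, a minimal usage sketch of the storyboard step (the prompt text and image paths are made up for illustration; the singleton name and signature are the ones documented above):

```python
from gemini_helpers import gemini_agent_singleton

# Hypothetical reference images; replace with real paths.
ref_images = ["refs/astronaut.png", "refs/mars_garden.png"]

storyboard, report = gemini_agent_singleton.generate_storyboard(
    prompt="A lone astronaut discovers a hidden garden on Mars.",
    num_keyframes=4,
    ref_image_paths=ref_images,
)
print(report)  # textual report of the operation
for idx, scene in enumerate(storyboard, start=1):
    print(f"Scene {idx}: {scene}")
```
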
### **flux_kontext_helpers.py (FluxPoolManager)**

* **Purpose:** Specialist in high-quality image (keyframe) generation using the FluxKontext pipeline. Manages a pool of workers to optimize the use of multiple GPUs.
* **Singleton Instance:** `flux_kontext_singleton`
* **Constructor:** `FluxPoolManager(device_ids: list[str], flux_config_file: str)`
    * Reads `configs/flux_config.yaml`.
* **Public Method:**
    * `generate_image(prompt: str, reference_images: list[Image.Image], width: int, height: int, seed: int = 42, callback: callable = None)`
        * **Inputs:**
            * `prompt`: Textual prompt guiding the generation (string).
            * `reference_images`: List of `PIL.Image` objects used as visual reference.
            * `width`, `height`: Output image dimensions (int).
            * `seed`: Seed for reproducibility (int).
            * `callback`: Optional callback function for monitoring progress.
        * **Output:** `PIL.Image.Image` (the generated image object).

### **dreamo_helpers.py (DreamOAgent)**

* **Purpose:** Specialist in high-quality image (keyframe) generation using the DreamO pipeline, with advanced editing and styling capabilities driven by references.
* **Singleton Instance:** `dreamo_agent_singleton`
* **Constructor:** `DreamOAgent(device_id: str = None)`
    * Reads `configs/dreamo_config.yaml`.
* **Public Method:**
    * `generate_image(prompt: str, reference_images: list[Image.Image], width: int, height: int)`
        * **Inputs:**
            * `prompt`: Textual prompt guiding the generation (string).
            * `reference_images`: List of `PIL.Image` objects used as visual reference. The internal logic assigns the first image as `style` and the remaining ones as `ip`.
            * `width`, `height`: Output image dimensions (int).
        * **Output:** `PIL.Image.Image` (the generated image object).

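Both image specialists expose the same `generate_image` surface; below is a hedged sketch of calling the Flux pool (the reference image path, prompt, and output path are placeholders):

```python
from PIL import Image
from flux_kontext_helpers import flux_kontext_singleton

# Placeholder reference image for illustration.
refs = [Image.open("refs/astronaut.png").convert("RGB")]

keyframe = flux_kontext_singleton.generate_image(
    prompt="Astronaut kneeling beside glowing alien flowers, cinematic lighting",
    reference_images=refs,
    width=1024,
    height=1024,
    seed=42,
)
keyframe.save("keyframes/keyframe_01.png")  # generate_image returns a PIL.Image.Image
```

Swapping `flux_kontext_singleton` for `dreamo_agent_singleton` (and dropping the `seed`/`callback` arguments) follows the DreamO signature documented above.
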
### **ltx_manager_helpers.py (LtxPoolManager)**

* **Purpose:** Specialist in generating video fragments in latent space using the LTX-Video pipeline. Manages a pool of workers to optimize the use of multiple GPUs.
* **Singleton Instance:** `ltx_manager_singleton`
* **Constructor:** `LtxPoolManager(device_ids: list[str], ltx_model_config_file: str, ltx_global_config_file: str)`
    * Reads `ltx_global_config_file` and `ltx_model_config_file` to configure the pipeline.
* **Public Method:**
    * `generate_latent_fragment(**kwargs)`
        * **Inputs:** A dictionary of keyword arguments (`kwargs`) containing all LTX pipeline parameters, including:
            * `height`, `width`: Video dimensions (int).
            * `video_total_frames`: Total number of frames to generate (int).
            * `video_fps`: Frames per second (int).
            * `motion_prompt`: Motion prompt (string).
            * `conditioning_items_data`: List of `LatentConditioningItem` objects containing the conditioning latent tensors.
            * `guidance_scale`, `stg_scale`, `num_inference_steps`, etc.
        * **Output:** `tuple[torch.Tensor, tuple]` (a tuple containing the generated latent tensor and the padding values used).

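A sketch of how a video engine might call the LTX pool. The `LatentConditioningItem` fields (latent tensor, frame index, strength) follow their usage in `deformes4D_engine.py`; the import location, tensor path, and numeric values are illustrative assumptions:

```python
import torch
from ltx_manager_helpers import ltx_manager_singleton
from deformes4D_engine import LatentConditioningItem  # assumed import location

# A previously saved latent tensor (path is illustrative); keep only its last frame.
start_latent = torch.load("workspace/fragment_0_full.pt")[:, :, -1:, :, :]

latents, padding = ltx_manager_singleton.generate_latent_fragment(
    height=512,
    width=512,
    video_total_frames=121,   # quantized to a multiple of 8, plus 1
    video_fps=24,
    motion_prompt="Slow dolly-in towards the hidden garden",
    conditioning_items_data=[LatentConditioningItem(start_latent, 0, 1.0)],
    guidance_scale=3.0,
    num_inference_steps=30,
)
```
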
| 105 |
+
### **mmaudio_helper.py (MMAudioAgent)**
|
| 106 |
+
|
| 107 |
+
* **Propósito:** Especialista em geração de áudio para um determinado fragmento de vídeo.
|
| 108 |
+
* **Singleton Instance:** `mmaudio_agent_singleton`
|
| 109 |
+
* **Construtor:** `MMAudioAgent(workspace_dir: str, device_id: str = None, mmaudio_config_file: str)`
|
| 110 |
+
* Lê `configs/mmaudio_config.yaml`.
|
| 111 |
+
* **Método Público:**
|
| 112 |
+
* `generate_audio_for_video(video_path: str, prompt: str, negative_prompt: str, duration_seconds: float)`
|
| 113 |
+
* **Inputs:**
|
| 114 |
+
* `video_path`: Caminho para o arquivo de vídeo silencioso (string).
|
| 115 |
+
* `prompt`: Prompt textual para guiar a geração de áudio (string).
|
| 116 |
+
* `negative_prompt`: Prompt negativo para áudio (string).
|
| 117 |
+
* `duration_seconds`: Duração exata do vídeo (float).
|
| 118 |
+
* **Output:** `str` (O caminho para o novo arquivo de vídeo com a faixa de áudio integrada).
|
| 119 |
+
|
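Finally, a sketch of the audio muxing step (paths and prompts are placeholders):

```python
from mmaudio_helper import mmaudio_agent_singleton

video_with_audio = mmaudio_agent_singleton.generate_audio_for_video(
    video_path="workspace/fragment_0_silent.mp4",  # silent clip produced earlier
    prompt="Soft wind and the distant hum of a life-support system",
    negative_prompt="speech, music with vocals",
    duration_seconds=5.0,
)
print(video_with_audio)  # path to the clip with the audio track muxed in
```
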
---

## 🔗 Original Projects and Attributions
(The attributions and licenses section remains the same as previously defined.)

### DreamO
* **Original Repository:** [https://github.com/bytedance/DreamO](https://github.com/bytedance/DreamO)
...

### LTX-Video
* **Original Repository:** [https://github.com/Lightricks/LTX-Video](https://github.com/Lightricks/LTX-Video)
...

### MMAudio
* **Original Repository:** [https://github.com/hkchengrex/MMAudio](https://github.com/hkchengrex/MMAudio)
...
dreamo/dreamo_pipeline.py
ADDED
|
@@ -0,0 +1,507 @@
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
| 17 |
+
|
| 18 |
+
import diffusers
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
from diffusers import FluxPipeline
|
| 23 |
+
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
|
| 24 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
| 25 |
+
from einops import repeat
|
| 26 |
+
from huggingface_hub import hf_hub_download
|
| 27 |
+
from safetensors.torch import load_file
|
| 28 |
+
|
| 29 |
+
from dreamo.transformer import flux_transformer_forward
|
| 30 |
+
from dreamo.utils import convert_flux_lora_to_diffusers
|
| 31 |
+
|
| 32 |
+
diffusers.models.transformers.transformer_flux.FluxTransformer2DModel.forward = flux_transformer_forward
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_task_embedding_idx(task):
|
| 36 |
+
return 0
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class DreamOPipeline(FluxPipeline):
|
| 40 |
+
def __init__(self, scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer):
|
| 41 |
+
super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer)
|
| 42 |
+
self.t5_embedding = nn.Embedding(10, 4096)
|
| 43 |
+
self.task_embedding = nn.Embedding(2, 3072)
|
| 44 |
+
self.idx_embedding = nn.Embedding(10, 3072)
|
| 45 |
+
|
| 46 |
+
def load_dreamo_model(self, device, use_turbo=True, version='v1.1'):
|
| 47 |
+
# download models and load file
|
| 48 |
+
hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo.safetensors', local_dir='models')
|
| 49 |
+
hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_cfg_distill.safetensors', local_dir='models')
|
| 50 |
+
if version == 'v1':
|
| 51 |
+
hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_pos.safetensors',
|
| 52 |
+
local_dir='models')
|
| 53 |
+
hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_neg.safetensors',
|
| 54 |
+
local_dir='models')
|
| 55 |
+
quality_lora_pos = load_file('models/dreamo_quality_lora_pos.safetensors')
|
| 56 |
+
quality_lora_neg = load_file('models/dreamo_quality_lora_neg.safetensors')
|
| 57 |
+
elif version == 'v1.1':
|
| 58 |
+
hf_hub_download(repo_id='ByteDance/DreamO', filename='v1.1/dreamo_sft_lora.safetensors', local_dir='models')
|
| 59 |
+
hf_hub_download(repo_id='ByteDance/DreamO', filename='v1.1/dreamo_dpo_lora.safetensors', local_dir='models')
|
| 60 |
+
sft_lora = load_file('models/v1.1/dreamo_sft_lora.safetensors')
|
| 61 |
+
dpo_lora = load_file('models/v1.1/dreamo_dpo_lora.safetensors')
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f'there is no {version}')
|
| 64 |
+
dreamo_lora = load_file('models/dreamo.safetensors')
|
| 65 |
+
cfg_distill_lora = load_file('models/dreamo_cfg_distill.safetensors')
|
| 66 |
+
|
| 67 |
+
# load embedding
|
| 68 |
+
self.t5_embedding.weight.data = dreamo_lora.pop('dreamo_t5_embedding.weight')[-10:]
|
| 69 |
+
self.task_embedding.weight.data = dreamo_lora.pop('dreamo_task_embedding.weight')
|
| 70 |
+
self.idx_embedding.weight.data = dreamo_lora.pop('dreamo_idx_embedding.weight')
|
| 71 |
+
self._prepare_t5()
|
| 72 |
+
|
| 73 |
+
# main lora
|
| 74 |
+
dreamo_diffuser_lora = convert_flux_lora_to_diffusers(dreamo_lora)
|
| 75 |
+
adapter_names = ['dreamo']
|
| 76 |
+
adapter_weights = [1]
|
| 77 |
+
self.load_lora_weights(dreamo_diffuser_lora, adapter_name='dreamo')
|
| 78 |
+
|
| 79 |
+
# cfg lora to avoid true image cfg
|
| 80 |
+
cfg_diffuser_lora = convert_flux_lora_to_diffusers(cfg_distill_lora)
|
| 81 |
+
self.load_lora_weights(cfg_diffuser_lora, adapter_name='cfg')
|
| 82 |
+
adapter_names.append('cfg')
|
| 83 |
+
adapter_weights.append(1)
|
| 84 |
+
|
| 85 |
+
# turbo lora to speed up (from 25+ step to 12 step)
|
| 86 |
+
if use_turbo:
|
| 87 |
+
self.load_lora_weights(
|
| 88 |
+
hf_hub_download(
|
| 89 |
+
"alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors", local_dir='models'
|
| 90 |
+
),
|
| 91 |
+
adapter_name='turbo',
|
| 92 |
+
)
|
| 93 |
+
adapter_names.append('turbo')
|
| 94 |
+
adapter_weights.append(1)
|
| 95 |
+
|
| 96 |
+
if version == 'v1':
|
| 97 |
+
# quality loras, one pos, one neg
|
| 98 |
+
quality_lora_pos = convert_flux_lora_to_diffusers(quality_lora_pos)
|
| 99 |
+
self.load_lora_weights(quality_lora_pos, adapter_name='quality_pos')
|
| 100 |
+
adapter_names.append('quality_pos')
|
| 101 |
+
adapter_weights.append(0.15)
|
| 102 |
+
quality_lora_neg = convert_flux_lora_to_diffusers(quality_lora_neg)
|
| 103 |
+
self.load_lora_weights(quality_lora_neg, adapter_name='quality_neg')
|
| 104 |
+
adapter_names.append('quality_neg')
|
| 105 |
+
adapter_weights.append(-0.8)
|
| 106 |
+
elif version == 'v1.1':
|
| 107 |
+
self.load_lora_weights(sft_lora, adapter_name='sft_lora')
|
| 108 |
+
adapter_names.append('sft_lora')
|
| 109 |
+
adapter_weights.append(1)
|
| 110 |
+
self.load_lora_weights(dpo_lora, adapter_name='dpo_lora')
|
| 111 |
+
adapter_names.append('dpo_lora')
|
| 112 |
+
adapter_weights.append(1.25)
|
| 113 |
+
|
| 114 |
+
self.set_adapters(adapter_names, adapter_weights)
|
| 115 |
+
self.fuse_lora(adapter_names=adapter_names, lora_scale=1)
|
| 116 |
+
self.unload_lora_weights()
|
| 117 |
+
|
| 118 |
+
self.t5_embedding = self.t5_embedding.to(device)
|
| 119 |
+
self.task_embedding = self.task_embedding.to(device)
|
| 120 |
+
self.idx_embedding = self.idx_embedding.to(device)
|
| 121 |
+
|
| 122 |
+
def _prepare_t5(self):
|
| 123 |
+
self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
|
| 124 |
+
num_new_token = 10
|
| 125 |
+
new_token_list = [f"[ref#{i}]" for i in range(1, 10)] + ["[res]"]
|
| 126 |
+
self.tokenizer_2.add_tokens(new_token_list, special_tokens=False)
|
| 127 |
+
self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
|
| 128 |
+
input_embedding = self.text_encoder_2.get_input_embeddings().weight.data
|
| 129 |
+
input_embedding[-num_new_token:] = self.t5_embedding.weight.data
|
| 130 |
+
|
| 131 |
+
@staticmethod
|
| 132 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0):
|
| 133 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
| 134 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + start_height
|
| 135 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + start_width
|
| 136 |
+
|
| 137 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 138 |
+
|
| 139 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
| 140 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 141 |
+
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 145 |
+
|
| 146 |
+
@staticmethod
|
| 147 |
+
def _prepare_style_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0):
|
| 148 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
| 149 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + start_height
|
| 150 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + start_width
|
| 151 |
+
|
| 152 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
| 153 |
+
|
| 154 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
| 155 |
+
latent_image_ids = latent_image_ids.reshape(
|
| 156 |
+
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
| 160 |
+
|
| 161 |
+
@torch.no_grad()
|
| 162 |
+
def __call__(
|
| 163 |
+
self,
|
| 164 |
+
prompt: Union[str, List[str]] = None,
|
| 165 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 166 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 167 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 168 |
+
true_cfg_scale: float = 1.0,
|
| 169 |
+
true_cfg_start_step: int = 1,
|
| 170 |
+
true_cfg_end_step: int = 1,
|
| 171 |
+
height: Optional[int] = None,
|
| 172 |
+
width: Optional[int] = None,
|
| 173 |
+
num_inference_steps: int = 28,
|
| 174 |
+
sigmas: Optional[List[float]] = None,
|
| 175 |
+
guidance_scale: float = 3.5,
|
| 176 |
+
neg_guidance_scale: float = 3.5,
|
| 177 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 178 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 179 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 180 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 181 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 182 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 183 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 184 |
+
output_type: Optional[str] = "pil",
|
| 185 |
+
return_dict: bool = True,
|
| 186 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 187 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 188 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 189 |
+
max_sequence_length: int = 512,
|
| 190 |
+
ref_conds=None,
|
| 191 |
+
first_step_guidance_scale=3.5,
|
| 192 |
+
):
|
| 193 |
+
r"""
|
| 194 |
+
Function invoked when calling the pipeline for generation.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 198 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 199 |
+
instead.
|
| 200 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 201 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
will be used instead.
|
| 203 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 204 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 205 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
| 206 |
+
not greater than `1`).
|
| 207 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 208 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
| 209 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
| 210 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
| 211 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
| 212 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 213 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 214 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
| 215 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
| 216 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 217 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 218 |
+
expense of slower inference.
|
| 219 |
+
sigmas (`List[float]`, *optional*):
|
| 220 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 221 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 222 |
+
will be used.
|
| 223 |
+
guidance_scale (`float`, *optional*, defaults to 3.5):
|
| 224 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 225 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 226 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 227 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 228 |
+
usually at the expense of lower image quality.
|
| 229 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 230 |
+
The number of images to generate per prompt.
|
| 231 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 232 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 233 |
+
to make generation deterministic.
|
| 234 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 235 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 236 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 237 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 238 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 239 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 240 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 241 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 242 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
| 243 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
| 244 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 245 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 246 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 247 |
+
argument.
|
| 248 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 249 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 250 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
| 251 |
+
input argument.
|
| 252 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 253 |
+
The output format of the generate image. Choose between
|
| 254 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 255 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 256 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
| 257 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 258 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 259 |
+
`self.processor` in
|
| 260 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 261 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 262 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 263 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 264 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 265 |
+
`callback_on_step_end_tensor_inputs`.
|
| 266 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 267 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 268 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 269 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 270 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
| 271 |
+
|
| 272 |
+
Examples:
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
| 276 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
| 277 |
+
images.
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 281 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 282 |
+
|
| 283 |
+
# 1. Check inputs. Raise error if not correct
|
| 284 |
+
self.check_inputs(
|
| 285 |
+
prompt,
|
| 286 |
+
prompt_2,
|
| 287 |
+
height,
|
| 288 |
+
width,
|
| 289 |
+
prompt_embeds=prompt_embeds,
|
| 290 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 291 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 292 |
+
max_sequence_length=max_sequence_length,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
self._guidance_scale = guidance_scale
|
| 296 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 297 |
+
self._current_timestep = None
|
| 298 |
+
self._interrupt = False
|
| 299 |
+
|
| 300 |
+
# 2. Define call parameters
|
| 301 |
+
if prompt is not None and isinstance(prompt, str):
|
| 302 |
+
batch_size = 1
|
| 303 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 304 |
+
batch_size = len(prompt)
|
| 305 |
+
else:
|
| 306 |
+
batch_size = prompt_embeds.shape[0]
|
| 307 |
+
|
| 308 |
+
device = self._execution_device
|
| 309 |
+
|
| 310 |
+
lora_scale = (
|
| 311 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
| 312 |
+
)
|
| 313 |
+
has_neg_prompt = negative_prompt is not None or (
|
| 314 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
| 315 |
+
)
|
| 316 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
| 317 |
+
(
|
| 318 |
+
prompt_embeds,
|
| 319 |
+
pooled_prompt_embeds,
|
| 320 |
+
text_ids,
|
| 321 |
+
) = self.encode_prompt(
|
| 322 |
+
prompt=prompt,
|
| 323 |
+
prompt_2=prompt_2,
|
| 324 |
+
prompt_embeds=prompt_embeds,
|
| 325 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 326 |
+
device=device,
|
| 327 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 328 |
+
max_sequence_length=max_sequence_length,
|
| 329 |
+
lora_scale=lora_scale,
|
| 330 |
+
)
|
| 331 |
+
if do_true_cfg:
|
| 332 |
+
(
|
| 333 |
+
negative_prompt_embeds,
|
| 334 |
+
negative_pooled_prompt_embeds,
|
| 335 |
+
_,
|
| 336 |
+
) = self.encode_prompt(
|
| 337 |
+
prompt=negative_prompt,
|
| 338 |
+
prompt_2=negative_prompt_2,
|
| 339 |
+
prompt_embeds=negative_prompt_embeds,
|
| 340 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 341 |
+
device=device,
|
| 342 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 343 |
+
max_sequence_length=max_sequence_length,
|
| 344 |
+
lora_scale=lora_scale,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# 4. Prepare latent variables
|
| 348 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 349 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 350 |
+
batch_size * num_images_per_prompt,
|
| 351 |
+
num_channels_latents,
|
| 352 |
+
height,
|
| 353 |
+
width,
|
| 354 |
+
prompt_embeds.dtype,
|
| 355 |
+
device,
|
| 356 |
+
generator,
|
| 357 |
+
latents,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# 4.1 concat ref tokens to latent
|
| 361 |
+
origin_img_len = latents.shape[1]
|
| 362 |
+
embeddings = repeat(self.task_embedding.weight[1], "c -> n l c", n=batch_size, l=origin_img_len)
|
| 363 |
+
ref_latents = []
|
| 364 |
+
ref_latent_image_idss = []
|
| 365 |
+
start_height = height // 16
|
| 366 |
+
start_width = width // 16
|
| 367 |
+
for ref_cond in ref_conds:
|
| 368 |
+
img = ref_cond['img'] # [b, 3, h, w], range [-1, 1]
|
| 369 |
+
task = ref_cond['task']
|
| 370 |
+
idx = ref_cond['idx']
|
| 371 |
+
|
| 372 |
+
# encode ref with VAE
|
| 373 |
+
img = img.to(latents)
            ref_latent = self.vae.encode(img).latent_dist.sample()
            ref_latent = (ref_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
            cur_height = ref_latent.shape[2]
            cur_width = ref_latent.shape[3]
            ref_latent = self._pack_latents(ref_latent, batch_size, num_channels_latents, cur_height, cur_width)
            ref_latent_image_ids = self._prepare_latent_image_ids(
                batch_size, cur_height, cur_width, device, prompt_embeds.dtype, start_height, start_width
            )
            start_height += cur_height // 2
            start_width += cur_width // 2

            # prepare task_idx_embedding
            task_idx = get_task_embedding_idx(task)
            cur_task_embedding = repeat(
                self.task_embedding.weight[task_idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1]
            )
            cur_idx_embedding = repeat(
                self.idx_embedding.weight[idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1]
            )
            cur_embedding = cur_task_embedding + cur_idx_embedding

            # concat ref to latent
            embeddings = torch.cat([embeddings, cur_embedding], dim=1)
            ref_latents.append(ref_latent)
            ref_latent_image_idss.append(ref_latent_image_ids)

        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
        image_seq_len = latents.shape[1]
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)

        # handle guidance
        if self.transformer.config.guidance_embeds:
            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
            guidance = guidance.expand(latents.shape[0])
        else:
            guidance = None
        neg_guidance = torch.full([1], neg_guidance_scale, device=device, dtype=torch.float32)
        neg_guidance = neg_guidance.expand(latents.shape[0])
        first_step_guidance = torch.full([1], first_step_guidance_scale, device=device, dtype=torch.float32)

        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latents.shape[0]).to(latents.dtype)

                noise_pred = self.transformer(
                    hidden_states=torch.cat((latents, *ref_latents), dim=1),
                    timestep=timestep / 1000,
                    guidance=guidance if i > 0 else first_step_guidance,
                    pooled_projections=pooled_prompt_embeds,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=torch.cat((latent_image_ids, *ref_latent_image_idss), dim=1),
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                    embeddings=embeddings,
                )[0][:, :origin_img_len]

                if do_true_cfg and i >= true_cfg_start_step and i < true_cfg_end_step:
                    neg_noise_pred = self.transformer(
                        hidden_states=latents,
                        timestep=timestep / 1000,
                        guidance=neg_guidance,
                        pooled_projections=negative_pooled_prompt_embeds,
                        encoder_hidden_states=negative_prompt_embeds,
                        txt_ids=text_ids,
                        img_ids=latent_image_ids,
                        joint_attention_kwargs=self.joint_attention_kwargs,
                        return_dict=False,
                    )[0]
                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        self._current_timestep = None

        if output_type == "latent":
            image = latents
        else:
            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
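Note on the `true_cfg_scale` blend in the denoising loop above: inside the configured step window the pipeline runs a second, negative-conditioned pass and extrapolates the prediction away from it. The snippet below only illustrates that mixing rule in isolation; the tensors and scale value are invented for demonstration and are not part of the pipeline.

import torch

# Stand-in tensors for the positive (reference-conditioned) and negative predictions.
noise_pred = torch.tensor([1.0, 2.0, 3.0])
neg_noise_pred = torch.tensor([0.5, 0.5, 0.5])

true_cfg_scale = 3.0  # hypothetical value; the real one arrives as a pipeline argument

# Same formula as in the loop:
blended = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# With true_cfg_scale == 1.0 the blend reduces to the positive prediction;
# larger values push the result further away from the negative branch.
assert torch.allclose(neg_noise_pred + 1.0 * (noise_pred - neg_noise_pred), noise_pred)
print(blended)  # tensor([2., 5., 8.])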
dreamo/transformer.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def flux_transformer_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    pooled_projections: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    guidance: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
    embeddings: torch.Tensor = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    """
    The [`FluxTransformer2DModel`] forward method.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
            from the embeddings of input conditions.
        timestep (`torch.LongTensor`):
            Used to indicate denoising step.
        block_controlnet_hidden_states: (`list` of `torch.Tensor`):
            A list of tensors that if specified are added to the residuals of transformer blocks.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

    hidden_states = self.x_embedder(hidden_states)
    # add task and idx embedding
    if embeddings is not None:
        hidden_states = hidden_states + embeddings

    timestep = timestep.to(hidden_states.dtype) * 1000
    guidance = guidance.to(hidden_states.dtype) * 1000 if guidance is not None else None

    temb = (
        self.time_text_embed(timestep, pooled_projections)
        if guidance is None
        else self.time_text_embed(timestep, guidance, pooled_projections)
    )
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)

    if txt_ids.ndim == 3:
        # logger.warning(
        #     "Passing `txt_ids` 3d torch.Tensor is deprecated."
        #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
        # )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        # logger.warning(
        #     "Passing `img_ids` 3d torch.Tensor is deprecated."
        #     "Please remove the batch dimension and pass it as a 2d torch Tensor"
        # )
        img_ids = img_ids[0]

    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)

    for index_block, block in enumerate(self.transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
            )

        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # controlnet residual
        if controlnet_block_samples is not None:
            interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
            interval_control = int(np.ceil(interval_control))
            # For Xlabs ControlNet.
            if controlnet_blocks_repeat:
                hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
            else:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    for index_block, block in enumerate(self.single_transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                temb,
                image_rotary_emb,
            )

        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # controlnet residual
        if controlnet_single_block_samples is not None:
            interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
            interval_control = int(np.ceil(interval_control))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )

    hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
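`flux_transformer_forward` mirrors the stock diffusers Flux forward pass but accepts an extra `embeddings` tensor that is added right after `x_embedder`, which is how the task/index embeddings built in the pipeline reach the transformer. A minimal sketch of attaching such a replacement forward to a loaded model follows; the exact wiring used by `dreamo_pipeline.py` is not visible in this hunk, so treat the binding step and the checkpoint name as assumptions.

import types

import torch
from diffusers import FluxTransformer2DModel

from dreamo.transformer import flux_transformer_forward

# Hypothetical checkpoint; any Flux-style transformer with the same layout would do.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Bind the patched forward to this instance so the extra `embeddings` kwarg is accepted.
transformer.forward = types.MethodType(flux_transformer_forward, transformer)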
dreamo/utils.py
ADDED
@@ -0,0 +1,232 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re

import cv2
import numpy as np
import torch
from torchvision.utils import make_grid


# from basicsr
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result


def resize_numpy_image_area(image, area=512 * 512):
    h, w = image.shape[:2]
    k = math.sqrt(area / (h * w))
    h = int(h * k) - (int(h * k) % 16)
    w = int(w * k) - (int(w * k) % 16)
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    return image


def resize_numpy_image_long(image, long_edge=768):
    h, w = image.shape[:2]
    if max(h, w) <= long_edge:
        return image
    k = long_edge / max(h, w)
    h = int(h * k)
    w = int(w * k)
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    return image


# reference: https://github.com/huggingface/diffusers/pull/9295/files
def convert_flux_lora_to_diffusers(old_state_dict):
    new_state_dict = {}
    orig_keys = list(old_state_dict.keys())

    def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        down_weight = sds_sd.pop(sds_key)
        up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))

        # calculate dims if not provided
        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # make ai-toolkit weight
        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

        # down_weight is copied to each split
        ait_sd.update({k: down_weight for k in ait_down_keys})

        # up_weight is split to each split
        ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416

    for old_key in orig_keys:
        # Handle double_blocks
        if 'double_blocks' in old_key:
            block_num = re.search(r"double_blocks_(\d+)", old_key).group(1)
            new_key = f"transformer.transformer_blocks.{block_num}"

            if "proj_lora1" in old_key:
                new_key += ".attn.to_out.0"
            elif "proj_lora2" in old_key:
                new_key += ".attn.to_add_out"
            elif "qkv_lora2" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
                        f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
                        f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
                    ],
                )
                # continue
            elif "qkv_lora1" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.transformer_blocks.{block_num}.attn.to_q",
                        f"transformer.transformer_blocks.{block_num}.attn.to_k",
                        f"transformer.transformer_blocks.{block_num}.attn.to_v",
                    ],
                )
                # continue

            if "down" in old_key:
                new_key += ".lora_A.weight"
            elif "up" in old_key:
                new_key += ".lora_B.weight"

        # Handle single_blocks
        elif 'single_blocks' in old_key:
            block_num = re.search(r"single_blocks_(\d+)", old_key).group(1)
            new_key = f"transformer.single_transformer_blocks.{block_num}"

            if "proj_lora" in old_key:
                new_key += ".proj_out"
            elif "qkv_lora" in old_key and "up" not in old_key:
                handle_qkv(
                    old_state_dict,
                    new_state_dict,
                    old_key,
                    [
                        f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
                        f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
                        f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
                    ],
                )

            if "down" in old_key:
                new_key += ".lora_A.weight"
            elif "up" in old_key:
                new_key += ".lora_B.weight"

        else:
            # Handle other potential key patterns here
            new_key = old_key

        # Since we already handle qkv above.
        if "qkv" not in old_key and 'embedding' not in old_key:
            new_state_dict[new_key] = old_state_dict.pop(old_key)

    # if len(old_state_dict) > 0:
    #     raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

    return new_state_dict
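`resize_numpy_image_area` rescales an image so its pixel count approximates `area` while snapping both sides down to multiples of 16 (a typical latent-grid constraint). A quick illustration with a synthetic array, assuming only the functions defined above; the printed shapes are what the math above implies, not measured output.

import numpy as np

from dreamo.utils import resize_numpy_image_area, img2tensor

# Synthetic 900x1600 BGR image (the layout cv2.imread would produce).
image = np.zeros((900, 1600, 3), dtype=np.uint8)

resized = resize_numpy_image_area(image, area=512 * 512)
print(resized.shape)  # (384, 672, 3): roughly 512*512 pixels, both sides divisible by 16

tensor = img2tensor(resized, bgr2rgb=True, float32=True)
print(tensor.shape)   # torch.Size([3, 384, 672])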
flux_kontext_helpers.py
ADDED
@@ -0,0 +1,151 @@
# flux_kontext_helpers.py (ADUC: O Especialista Pintor - com suporte a callback)
# Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos

import torch
from PIL import Image, ImageOps
import gc
from diffusers import FluxKontextPipeline
import huggingface_hub
import os
import threading
import yaml
import logging

from hardware_manager import hardware_manager

logger = logging.getLogger(__name__)

class FluxWorker:
    """Representa uma única instância do pipeline FluxKontext em um dispositivo."""
    def __init__(self, device_id='cuda:0'):
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        self.pipe = None
        self._load_pipe_to_cpu()

    def _load_pipe_to_cpu(self):
        if self.pipe is None:
            logger.info(f"FLUX Worker ({self.device}): Carregando modelo para a CPU...")
            self.pipe = FluxKontextPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
            ).to(self.cpu_device)
            logger.info(f"FLUX Worker ({self.device}): Modelo pronto na CPU.")

    def to_gpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"FLUX Worker: Movendo modelo para a GPU {self.device}...")
        self.pipe.to(self.device)

    def to_cpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"FLUX Worker: Descarregando modelo da GPU {self.device}...")
        self.pipe.to(self.cpu_device)
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()

    def _create_composite_reference(self, images: list[Image.Image], target_width: int, target_height: int) -> Image.Image:
        if not images: return None
        valid_images = [img.convert("RGB") for img in images if img is not None]
        if not valid_images: return None
        if len(valid_images) == 1:
            if valid_images[0].size != (target_width, target_height):
                return ImageOps.fit(valid_images[0], (target_width, target_height), Image.Resampling.LANCZOS)
            return valid_images[0]

        base_height = valid_images[0].height
        resized_for_concat = []
        for img in valid_images:
            if img.height != base_height:
                aspect_ratio = img.width / img.height
                new_width = int(base_height * aspect_ratio)
                resized_for_concat.append(img.resize((new_width, base_height), Image.Resampling.LANCZOS))
            else:
                resized_for_concat.append(img)

        total_width = sum(img.width for img in resized_for_concat)
        concatenated = Image.new('RGB', (total_width, base_height))
        x_offset = 0
        for img in resized_for_concat:
            concatenated.paste(img, (x_offset, 0))
            x_offset += img.width

        final_reference = ImageOps.fit(concatenated, (target_width, target_height), Image.Resampling.LANCZOS)
        return final_reference

    @torch.inference_mode()
    def generate_image_internal(self, reference_images: list[Image.Image], prompt: str, target_width: int, target_height: int, seed: int, callback: callable = None):
        composite_reference = self._create_composite_reference(reference_images, target_width, target_height)

        num_steps = 30  # Valor fixo otimizado

        logger.info(f"\n===== [CHAMADA AO PIPELINE FLUX em {self.device}] =====\n"
                    f"  - Prompt: '{prompt}'\n"
                    f"  - Resolução: {target_width}x{target_height}, Seed: {seed}, Passos: {num_steps}\n"
                    f"  - Nº de Imagens na Composição: {len(reference_images)}\n"
                    f"==========================================")

        generated_image = self.pipe(
            image=composite_reference,
            prompt=prompt,
            guidance_scale=2.5,
            width=target_width,
            height=target_height,
            num_inference_steps=num_steps,
            generator=torch.Generator(device="cpu").manual_seed(seed),
            callback_on_step_end=callback,
            callback_on_step_end_tensor_inputs=["latents"] if callback else None
        ).images[0]

        return generated_image

class FluxPoolManager:
    def __init__(self, device_ids):
        logger.info(f"FLUX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
        self.workers = [FluxWorker(device_id) for device_id in device_ids]
        self.current_worker_index = 0
        self.lock = threading.Lock()
        self.last_cleanup_thread = None

    def _cleanup_worker_thread(self, worker):
        logger.info(f"FLUX CLEANUP THREAD: Iniciando limpeza de {worker.device} em background...")
        worker.to_cpu()

    def generate_image(self, reference_images, prompt, width, height, seed=42, callback=None):
        worker_to_use = None
        try:
            with self.lock:
                if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
                    self.last_cleanup_thread.join()
                worker_to_use = self.workers[self.current_worker_index]
                previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
                worker_to_cleanup = self.workers[previous_worker_index]
                cleanup_thread = threading.Thread(target=self._cleanup_worker_thread, args=(worker_to_cleanup,))
                cleanup_thread.start()
                self.last_cleanup_thread = cleanup_thread
                worker_to_use.to_gpu()
                self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)

            logger.info(f"FLUX POOL MANAGER: Gerando imagem em {worker_to_use.device}...")
            return worker_to_use.generate_image_internal(
                reference_images=reference_images,
                prompt=prompt,
                target_width=width,
                target_height=height,
                seed=seed,
                callback=callback
            )
        except Exception as e:
            logger.error(f"FLUX POOL MANAGER: Erro durante a geração: {e}", exc_info=True)
            raise e
        finally:
            pass

# --- Instanciação Singleton Dinâmica ---
logger.info("Lendo config.yaml para inicializar o FluxKontext Pool Manager...")
with open("config.yaml", 'r') as f: config = yaml.safe_load(f)
hf_token = os.getenv('HF_TOKEN')
if hf_token: huggingface_hub.login(token=hf_token)
flux_gpus_required = config['specialists']['flux']['gpus_required']
flux_device_ids = hardware_manager.allocate_gpus('Flux', flux_gpus_required)
flux_kontext_singleton = FluxPoolManager(device_ids=flux_device_ids)
logger.info("Especialista de Imagem (Flux) pronto.")
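The module instantiates `flux_kontext_singleton` at import time, after reading `config.yaml` and reserving GPUs through `hardware_manager` (so simply importing it loads the FLUX.1-Kontext pipeline to CPU). Downstream code, such as `image_specialist.py` later in this commit, only needs the pool's `generate_image` call. A minimal usage sketch; the reference path and prompt are placeholders:

from PIL import Image

from flux_kontext_helpers import flux_kontext_singleton

# Hypothetical inputs for illustration only.
refs = [Image.open("workspace/keyframe_1.png")]
frame = flux_kontext_singleton.generate_image(
    reference_images=refs,
    prompt="the same character walking into a neon-lit alley, cinematic lighting",
    width=1024,
    height=576,
    seed=42,
)
frame.save("workspace/keyframe_2.png")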
gemini_helpers.py
ADDED
@@ -0,0 +1,257 @@
# gemini_helpers.py
# Copyright (C) 4 de Agosto de 2025 Carlos Rodrigues dos Santos
#
# Este programa é software livre: você pode redistribuí-lo e/ou modificá-lo
# sob os termos da Licença Pública Geral Affero GNU como publicada pela
# Free Software Foundation, seja a versão 3 da Licença, ou
# (a seu critério) qualquer versão posterior.
#
# AVISO DE PATENTE PENDENTE: O método e sistema ADUC implementado neste
# software está em processo de patenteamento. Consulte NOTICE.md.

import os
import logging
import json
import gradio as gr
from PIL import Image
import google.generativeai as genai
import re

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def robust_json_parser(raw_text: str) -> dict:
    clean_text = raw_text.strip()
    try:
        # Tenta encontrar o JSON delimitado por ```json ... ```
        match = re.search(r'```json\s*(\{.*?\})\s*```', clean_text, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)

        # Se não encontrar, tenta encontrar o primeiro '{' e o último '}'
        start_index = clean_text.find('{')
        end_index = clean_text.rfind('}')
        if start_index != -1 and end_index != -1 and end_index > start_index:
            json_str = clean_text[start_index : end_index + 1]
            return json.loads(json_str)
        else:
            raise ValueError("Nenhum objeto JSON válido foi encontrado na resposta da IA.")
    except json.JSONDecodeError as e:
        logger.error(f"Falha ao decodificar JSON. A IA retornou o seguinte texto:\n---\n{raw_text}\n---")
        raise ValueError(f"A IA retornou um formato de JSON inválido: {e}")

class GeminiSingleton:
    def __init__(self):
        self.api_key = os.environ.get("GEMINI_API_KEY")
        if self.api_key:
            genai.configure(api_key=self.api_key)
            # Modelo mais recente e capaz para tarefas complexas de visão e raciocínio.
            self.model = genai.GenerativeModel('gemini-2.5-flash')
            logger.info("Especialista Gemini (gemini-2.5-flash) inicializado com sucesso.")
        else:
            self.model = None
            logger.warning("Chave da API Gemini não encontrada. Especialista desativado.")

    def _check_model(self):
        if not self.model:
            raise gr.Error("A chave da API do Google Gemini não está configurada (GEMINI_API_KEY).")

    def _read_prompt_template(self, filename: str) -> str:
        try:
            with open(os.path.join("prompts", filename), "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            raise gr.Error(f"Arquivo de prompt não encontrado: prompts/{filename}")

    def generate_storyboard(self, prompt: str, num_keyframes: int, ref_image_paths: list[str]) -> list[str]:
        self._check_model()
        try:
            template = self._read_prompt_template("unified_storyboard_prompt.txt")
            storyboard_prompt = template.format(user_prompt=prompt, num_fragments=num_keyframes)
            model_contents = [storyboard_prompt] + [Image.open(p) for p in ref_image_paths]
            response = self.model.generate_content(model_contents)

            logger.info(f"--- RESPOSTA COMPLETA DO GEMINI (generate_storyboard) ---\n{response.text}\n--------------------")

            storyboard_data = robust_json_parser(response.text)
            storyboard = storyboard_data.get("scene_storyboard", [])
            if not storyboard or len(storyboard) != num_keyframes: raise ValueError("Número incorreto de cenas gerado.")
            return storyboard
        except Exception as e:
            raise gr.Error(f"O Roteirista (Gemini) falhou: {e}")

    def select_keyframes_from_pool(self, storyboard: list, base_image_paths: list[str], pool_image_paths: list[str]) -> list[str]:
        self._check_model()
        if not pool_image_paths:
            raise gr.Error("O 'banco de imagens' (Imagens Adicionais) está vazio.")

        try:
            template = self._read_prompt_template("keyframe_selection_prompt.txt")

            image_map = {f"IMG-{i+1}": path for i, path in enumerate(pool_image_paths)}
            base_image_map = {f"BASE-{i+1}": path for i, path in enumerate(base_image_paths)}

            model_contents = ["# Reference Images (Story Base)"]
            for identifier, path in base_image_map.items():
                model_contents.extend([f"Identifier: {identifier}", Image.open(path)])

            model_contents.append("\n# Image Pool (Scene Bank)")
            for identifier, path in image_map.items():
                model_contents.extend([f"Identifier: {identifier}", Image.open(path)])

            storyboard_str = "\n".join([f"- Scene {i+1}: {s}" for i, s in enumerate(storyboard)])
            selection_prompt = template.format(storyboard_str=storyboard_str, image_identifiers=list(image_map.keys()))
            model_contents.append(selection_prompt)

            response = self.model.generate_content(model_contents)

            logger.info(f"--- RESPOSTA COMPLETA DO GEMINI (select_keyframes_from_pool) ---\n{response.text}\n--------------------")

            selection_data = robust_json_parser(response.text)
            selected_identifiers = selection_data.get("selected_image_identifiers", [])

            if len(selected_identifiers) != len(storyboard):
                raise ValueError("A IA não selecionou o número correto de imagens para as cenas.")

            selected_paths = [image_map[identifier] for identifier in selected_identifiers]
            return selected_paths

        except Exception as e:
            raise gr.Error(f"O Fotógrafo (Gemini) falhou ao selecionar as imagens: {e}")

    def get_anticipatory_keyframe_prompt(self, global_prompt: str, scene_history: str, current_scene_desc: str, future_scene_desc: str, last_image_path: str, fixed_ref_paths: list[str]) -> str:
        self._check_model()
        try:
            template = self._read_prompt_template("anticipatory_keyframe_prompt.txt")

            director_prompt = template.format(
                historico_prompt=scene_history,
                cena_atual=current_scene_desc,
                cena_futura=future_scene_desc
            )

            model_contents = [
                "# CONTEXTO:",
                f"- Global Story Goal: {global_prompt}",
                "# VISUAL ASSETS:",
                "Current Base Image [IMG-BASE]:",
                Image.open(last_image_path)
            ]

            ref_counter = 1
            for path in fixed_ref_paths:
                if path != last_image_path:
                    model_contents.extend([f"General Reference Image [IMG-REF-{ref_counter}]:", Image.open(path)])
                    ref_counter += 1

            model_contents.append(director_prompt)

            response = self.model.generate_content(model_contents)

            logger.info(f"--- RESPOSTA COMPLETA DO GEMINI (get_anticipatory_keyframe_prompt) ---\n{response.text}\n--------------------")

            final_flux_prompt = response.text.strip()
            return final_flux_prompt
        except Exception as e:
            raise gr.Error(f"O Diretor de Arte (Gemini) falhou: {e}")

    def get_initial_motion_prompt(self, user_prompt: str, start_image_path: str, destination_image_path: str, dest_scene_desc: str) -> str:
        """Gera o prompt de movimento para a PRIMEIRA transição, que não tem um 'passado'."""
        self._check_model()
        try:
            template = self._read_prompt_template("initial_motion_prompt.txt")
            prompt_text = template.format(user_prompt=user_prompt, destination_scene_description=dest_scene_desc)
            model_contents = [
                prompt_text,
                "START Image:",
                Image.open(start_image_path),
                "DESTINATION Image:",
                Image.open(destination_image_path)
            ]
            response = self.model.generate_content(model_contents)

            logger.info(f"--- RESPOSTA COMPLETA DO GEMINI (get_initial_motion_prompt) ---\n{response.text}\n--------------------")

            return response.text.strip()
        except Exception as e:
            raise gr.Error(f"O Cineasta Inicial (Gemini) falhou: {e}")

    def get_cinematic_decision(self, global_prompt: str, story_history: str,
                               past_keyframe_path: str, present_keyframe_path: str, future_keyframe_path: str,
                               past_scene_desc: str, present_scene_desc: str, future_scene_desc: str) -> dict:
        """
        Atua como um 'Cineasta', analisando passado, presente e futuro para tomar decisões
        de edição e gerar prompts de movimento detalhados.
        """
        self._check_model()
        try:
            template = self._read_prompt_template("cinematic_director_prompt.txt")
            prompt_text = template.format(
                global_prompt=global_prompt,
                story_history=story_history,
                past_scene_desc=past_scene_desc,
                present_scene_desc=present_scene_desc,
                future_scene_desc=future_scene_desc
            )

            model_contents = [
                prompt_text,
                "[PAST_IMAGE]:", Image.open(past_keyframe_path),
                "[PRESENT_IMAGE]:", Image.open(present_keyframe_path),
                "[FUTURE_IMAGE]:", Image.open(future_keyframe_path)
            ]

            response = self.model.generate_content(model_contents)

            logger.info(f"--- RESPOSTA COMPLETA DO GEMINI (get_cinematic_decision) ---\n{response.text}\n--------------------")

            decision_data = robust_json_parser(response.text)
            if "transition_type" not in decision_data or "motion_prompt" not in decision_data:
                raise ValueError("Resposta da IA (Cineasta) está mal formatada. Faltam 'transition_type' ou 'motion_prompt'.")
            return decision_data
        except Exception as e:
            # Fallback para uma decisão segura em caso de erro
            logger.error(f"O Diretor de Cinema (Gemini) falhou: {e}. Usando fallback para 'continuous'.")
            return {
                "transition_type": "continuous",
                "motion_prompt": f"A smooth, continuous cinematic transition from '{present_scene_desc}' to '{future_scene_desc}'."
            }

    def get_sound_director_prompt(self, audio_history: str,
                                  past_keyframe_path: str, present_keyframe_path: str, future_keyframe_path: str,
                                  present_scene_desc: str, motion_prompt: str, future_scene_desc: str) -> str:
        """
        Atua como um 'Diretor de Som', analisando o contexto completo para criar um prompt
        de áudio imersivo e contínuo para a cena atual.
        """
        self._check_model()
        try:
            template = self._read_prompt_template("sound_director_prompt.txt")
            prompt_text = template.format(
                audio_history=audio_history,
                present_scene_desc=present_scene_desc,
                motion_prompt=motion_prompt,
                future_scene_desc=future_scene_desc
            )

            model_contents = [
                prompt_text,
                "[PAST_IMAGE]:", Image.open(past_keyframe_path),
                "[PRESENT_IMAGE]:", Image.open(present_keyframe_path),
                "[FUTURE_IMAGE]:", Image.open(future_keyframe_path)
            ]

            response = self.model.generate_content(model_contents)

            logger.info(f"--- RESPOSTA COMPLETA DO GEMINI (get_sound_director_prompt) ---\n{response.text}\n--------------------")

            return response.text.strip()
        except Exception as e:
            logger.error(f"O Diretor de Som (Gemini) falhou: {e}. Usando fallback.")
            return f"Sound effects matching the scene: {present_scene_desc}"


gemini_singleton = GeminiSingleton()
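`robust_json_parser` is the safety net for every structured Gemini reply: it first looks for a fenced json block and, failing that, takes the outermost brace-delimited span. A small self-contained check of both paths; the reply strings are invented for illustration.

from gemini_helpers import robust_json_parser

fenced = "Here is the plan:\n```json\n{\"scene_storyboard\": [\"a\", \"b\"]}\n```"
bare = "Sure! {\"transition_type\": \"cut\", \"motion_prompt\": \"whip pan\"} Hope this helps."

print(robust_json_parser(fenced)["scene_storyboard"])  # ['a', 'b']
print(robust_json_parser(bare)["transition_type"])     # cut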
hardware_manager.py
ADDED
@@ -0,0 +1,35 @@
# hardware_manager.py
# Gerencia a detecção e alocação de GPUs para os especialistas.
# Copyright (C) 2025 Carlos Rodrigues dos Santos

import torch
import logging

logger = logging.getLogger(__name__)

class HardwareManager:
    def __init__(self):
        self.gpus = []
        self.allocated_gpus = set()
        if torch.cuda.is_available():
            self.gpus = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
            logger.info(f"Hardware Manager: Encontradas {len(self.gpus)} GPUs disponíveis: {self.gpus}")

    def allocate_gpus(self, specialist_name: str, num_required: int) -> list[str]:
        if not self.gpus or num_required == 0:
            logger.warning(f"Nenhuma GPU disponível ou solicitada para '{specialist_name}'. Alocando para CPU.")
            return ['cpu']

        available_gpus = [gpu for gpu in self.gpus if gpu not in self.allocated_gpus]

        if len(available_gpus) < num_required:
            error_msg = f"Recursos de GPU insuficientes para '{specialist_name}'. Solicitado: {num_required}, Disponível: {len(available_gpus)}."
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        allocated = available_gpus[:num_required]
        self.allocated_gpus.update(allocated)
        logger.info(f"Hardware Manager: Alocando GPUs {allocated} para o especialista '{specialist_name}'.")
        return allocated

hardware_manager = HardwareManager()
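`HardwareManager` hands out exclusive GPU ids on a first-come basis and raises once the pool is exhausted, which is what lets each specialist claim its own devices at import time. A short sketch of the expected behaviour on a hypothetical 2-GPU machine:

from hardware_manager import hardware_manager

# On a machine with two CUDA devices this would print ['cuda:0'] then ['cuda:1'];
# without CUDA both calls fall back to ['cpu'].
print(hardware_manager.allocate_gpus("Flux", 1))
print(hardware_manager.allocate_gpus("LTX", 1))

# A third request for a dedicated GPU on that same 2-GPU machine would raise RuntimeError.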
i18n.json
ADDED
@@ -0,0 +1,128 @@
| 1 |
+
{
|
| 2 |
+
"pt": {
|
| 3 |
+
"app_title": "ADUC-SDR 🎬 - O Diretor de Cinema IA",
|
| 4 |
+
"app_subtitle": "Crie um filme completo com vídeo e áudio, orquestrado por uma equipe de IAs.",
|
| 5 |
+
"lang_selector_label": "Idioma / Language",
|
| 6 |
+
"step1_accordion": "Etapa 1: Roteiro e Cenas-Chave",
|
| 7 |
+
"prompt_label": "Ideia Geral do Filme",
|
| 8 |
+
"ref_images_base_label": "Imagens de Referência (Base da História)",
|
| 9 |
+
"ref_images_extra_label": "Imagens Adicionais (Banco de Cenas para o Modo Fotógrafo)",
|
| 10 |
+
"keyframes_label": "Número de Cenas-Chave",
|
| 11 |
+
"storyboard_button": "1. Gerar Roteiro",
|
| 12 |
+
"storyboard_and_keyframes_button": "1A. Gerar Roteiro e Keyframes (Modo Diretor de Arte)",
|
| 13 |
+
"storyboard_from_photos_button": "1B. Gerar Roteiro a partir de Fotos (Modo Fotógrafo)",
|
| 14 |
+
"step1_mode_b_info": "Modo Fotógrafo: As 'Imagens Adicionais' são usadas como um banco de cenas e a IA escolherá a melhor para cada parte do roteiro.",
|
| 15 |
+
"storyboard_output_label": "Roteiro Gerado (Storyboard)",
|
| 16 |
+
"step2_accordion": "Etapa 2: Os Keyframes (Especialista: Flux)",
|
| 17 |
+
"step2_description": "O Diretor de Arte (Gemini) guiará o Pintor (Flux) para criar as imagens-chave da sua história.",
|
| 18 |
+
"art_director_label": "Usar Diretor de Arte IA (para prompts de keyframe)",
|
| 19 |
+
"keyframes_button": "2. Gerar Imagens-Chave",
|
| 20 |
+
"keyframes_gallery_label": "Galeria de Cenas-Chave (Keyframes)",
|
| 21 |
+
"manual_keyframes_label": "Carregar Keyframes Manualmente",
|
| 22 |
+
"manual_separator": "--- OU ---",
|
| 23 |
+
"step3_accordion": "Etapa 3: A Produção do Filme (Especialistas: LTX & MMAudio)",
|
| 24 |
+
"step3_description": "O Diretor de Continuidade e o Cineasta irão guiar a Câmera (LTX) para filmar as transições entre os keyframes.",
|
| 25 |
+
"continuity_director_label": "Usar Diretor de Continuidade IA (para cortes)",
|
| 26 |
+
"cinematographer_label": "Usar Cineasta IA (para prompts de movimento)",
|
| 27 |
+
"duration_label": "Duração por Cena (s)",
|
| 28 |
+
"n_corte_label": "Ponto de Corte Base (%)",
|
| 29 |
+
"n_corte_info": "Percentual base da cena a ser substituído pela transição. Será ajustado dinamicamente.",
|
| 30 |
+
"convergence_chunks_label": "Máx. Chunks de Convergência",
|
| 31 |
+
"convergence_chunks_info": "Nº máx. de chunks latentes (memória) para guiar a convergência do movimento. Será ajustado dinamicamente.",
|
| 32 |
+
"path_convergence_label": "Força do Handler (Tensor)",
|
| 33 |
+
"destination_convergence_label": "Convergência do Destino (Tensor)",
|
| 34 |
+
"produce_button": "3. 🎬 Produzir Filme Completo (com Som)",
|
| 35 |
+
"advanced_accordion_label": "Configurações Avançadas (LTX)",
|
| 36 |
+
"guidance_label": "Guidance Scale",
|
| 37 |
+
"stg_label": "STG Scale",
|
| 38 |
+
"rescaling_label": "Rescaling Scale",
|
| 39 |
+
"steps_label": "Passos de Inferência",
|
| 40 |
+
"steps_info": "Mais passos podem melhorar a qualidade, mas aumentam o tempo. Ignorado para modelos 'distilled'.",
|
| 41 |
+
"video_fragments_gallery_label": "Fragmentos do Filme Gerados",
|
| 42 |
+
"final_movie_with_audio_label": "🎉 FILME COMPLETO 🎉"
|
| 43 |
+
},
|
| 44 |
+
"en": {
|
| 45 |
+
"app_title": "ADUC-SDR 🎬 - The AI Film Director",
|
| 46 |
+
"app_subtitle": "Create a complete film with video and audio, orchestrated by a team of AIs.",
|
| 47 |
+
"lang_selector_label": "Language / Idioma",
|
| 48 |
+
"step1_accordion": "Step 1: Script & Key Scenes",
|
| 49 |
+
"prompt_label": "General Film Idea",
|
| 50 |
+
"ref_images_base_label": "Reference Images (Story Base)",
|
| 51 |
+
"ref_images_extra_label": "Additional Images (Scene Bank for Photographer Mode)",
|
| 52 |
+
"keyframes_label": "Number of Key-Scenes",
|
| 53 |
+
"storyboard_button": "1. Generate Script",
|
| 54 |
+
"storyboard_and_keyframes_button": "1A. Generate Script & Keyframes (Art Director Mode)",
|
| 55 |
+
"storyboard_from_photos_button": "1B. Generate Script from Photos (Photographer Mode)",
|
| 56 |
+
"step1_mode_b_info": "Photographer Mode: 'Additional Images' are used as a scene bank, and the AI will choose the best one for each script part.",
|
| 57 |
+
"storyboard_output_label": "Generated Script (Storyboard)",
|
| 58 |
+
"step2_accordion": "Step 2: The Keyframes (Specialist: Flux)",
|
| 59 |
+
"step2_description": "The Art Director (Gemini) will guide the Painter (Flux) to create the key images of your story.",
|
| 60 |
+
"art_director_label": "Use AI Art Director (for keyframe prompts)",
|
| 61 |
+
"keyframes_button": "2. Generate Key-Images",
|
| 62 |
+
"keyframes_gallery_label": "Key-Scenes Gallery (Keyframes)",
|
| 63 |
+
"manual_keyframes_label": "Upload Keyframes Manually",
|
| 64 |
+
"manual_separator": "--- OR ---",
|
| 65 |
+
"step3_accordion": "Step 3: Film Production (Specialists: LTX & MMAudio)",
|
| 66 |
+
"step3_description": "The Continuity Director and Cinematographer will guide the Camera (LTX) to shoot the transitions between keyframes.",
|
| 67 |
+
"continuity_director_label": "Use AI Continuity Director (for cuts)",
|
| 68 |
+
"cinematographer_label": "Use AI Cinematographer (for motion prompts)",
|
| 69 |
+
"duration_label": "Duration per Scene (s)",
|
| 70 |
+
"n_corte_label": "Base Cut Point (%)",
|
| 71 |
+
"n_corte_info": "Base percentage of the scene to be replaced by the transition. Will be adjusted dynamically.",
|
| 72 |
+
"convergence_chunks_label": "Max Convergence Chunks",
|
| 73 |
+
"convergence_chunks_info": "Max number of latent chunks (memory) to guide motion convergence. Will be adjusted dynamically.",
|
| 74 |
+
"path_convergence_label": "Handler Strength (Tensor)",
|
| 75 |
+
"destination_convergence_label": "Destination Convergence (Tensor)",
|
| 76 |
+
"produce_button": "3. 🎬 Produce Complete Film (with Sound)",
|
| 77 |
+
"advanced_accordion_label": "Advanced Settings (LTX)",
|
| 78 |
+
"guidance_label": "Guidance Scale",
|
| 79 |
+
"stg_label": "STG Scale",
|
| 80 |
+
"rescaling_label": "Rescaling Scale",
|
| 81 |
+
"steps_label": "Inference Steps",
|
| 82 |
+
"steps_info": "More steps can improve quality but increase generation time. Ignored for 'distilled' models.",
|
| 83 |
+
"video_fragments_gallery_label": "Generated Film Fragments",
|
| 84 |
+
"final_movie_with_audio_label": "🎉 COMPLETE MOVIE 🎉"
|
| 85 |
+
},
|
| 86 |
+
"zh": {
|
| 87 |
+
"app_title": "ADUC-SDR 🎬 - 人工智能电影导演",
|
| 88 |
+
"app_subtitle": "由人工智能团队精心策划,根据一个想法和参考图像创作一部完整的有声电影。",
|
| 89 |
+
"lang_selector_label": "语言 / Language",
|
| 90 |
+
"step1_accordion": "第 1 步:剧本和关键场景",
|
| 91 |
+
"prompt_label": "电影总体构想",
|
| 92 |
+
"ref_images_base_label": "参考图像 (故事基础)",
|
| 93 |
+
"ref_images_extra_label": "附加图像 (摄影师模式的场景库)",
|
| 94 |
+
"keyframes_label": "关键场景数量",
|
| 95 |
+
"storyboard_button": "1. 生成剧本",
|
| 96 |
+
"storyboard_and_keyframes_button": "1A. 生成剧本和关键帧 (艺术总监模式)",
|
| 97 |
+
"storyboard_from_photos_button": "1B. 从照片生成剧本 (摄影师模式)",
|
| 98 |
+
"step1_mode_b_info": "摄影师模式:“附加图像”被用作场景库,AI将为剧本的每个部分选择最佳图像。",
|
| 99 |
+
"storyboard_output_label": "生成的剧本",
|
| 100 |
+
"step2_accordion": "第 2 步:关键帧 (专家: Flux)",
|
| 101 |
+
"step2_description": "艺术总监 (Gemini) 将指导画家 (Flux) 创作故事的关键图像。",
|
| 102 |
+
"art_director_label": "使用AI艺术总监",
|
| 103 |
+
"keyframes_button": "2. 生成关键图像",
|
| 104 |
+
"keyframes_gallery_label": "关键场景画廊 (关键帧)",
|
| 105 |
+
"manual_keyframes_label": "手动上传关键帧",
|
| 106 |
+
"manual_separator": "--- 或者 ---",
|
| 107 |
+
"step3_accordion": "第 3 步:影片制作 (专家: LTX & MMAudio)",
|
| 108 |
+
"step3_description": "连续性导演和电影摄影师将指导摄像机 (LTX) 拍摄关键帧之间的过渡。",
|
| 109 |
+
"continuity_director_label": "使用AI连续性导演",
|
| 110 |
+
"cinematographer_label": "使用AI电影摄影师",
|
| 111 |
+
"duration_label": "每场景时长 (秒)",
|
| 112 |
+
"n_corte_label": "基础剪辑点 (%)",
|
| 113 |
+
"n_corte_info": "将被过渡替换的场景基础百分比。将动态调整。",
|
| 114 |
+
"convergence_chunks_label": "最大收敛块",
|
| 115 |
+
"convergence_chunks_info": "用于引导运动收敛的最大潜在块(内存)数量。将动态调整。",
|
| 116 |
+
"path_convergence_label": "处理器强度 (张量)",
|
| 117 |
+
"destination_convergence_label": "目标收敛 (张量)",
|
| 118 |
+
"produce_button": "3. 🎬 制作完整影片 (有声)",
|
| 119 |
+
"advanced_accordion_label": "高级设置 (LTX)",
|
| 120 |
+
"guidance_label": "引导比例",
|
| 121 |
+
"stg_label": "STG 比例",
|
| 122 |
+
"rescaling_label": "重缩放比例",
|
| 123 |
+
"steps_label": "推理步骤",
|
| 124 |
+
"steps_info": "更多步骤可以提高质量,但会增加生成时间。对“distilled”模型无效。",
|
| 125 |
+
"video_fragments_gallery_label": "生成的电影片段",
|
| 126 |
+
"final_movie_with_audio_label": "🎉 完整影片 🎉"
|
| 127 |
+
}
|
| 128 |
+
}
|
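The table above maps UI label keys to localized strings per language code. As a minimal sketch only (not part of this commit), this is one way a Gradio front end could consume such a table; the helper name and the "en" fallback key are assumptions for the example.

# Hypothetical helper, not part of the commit: load one locale from i18n.json,
# falling back to English for any missing keys.
import json

def load_labels(i18n_path: str, lang: str, fallback: str = "en") -> dict:
    with open(i18n_path, "r", encoding="utf-8") as f:
        table = json.load(f)
    labels = dict(table.get(fallback, {}))   # start from the fallback locale
    labels.update(table.get(lang, {}))       # locale-specific strings win
    return labels

# Example: load_labels("i18n.json", "zh")["produce_button"]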
image_specialist.py
ADDED
@@ -0,0 +1,98 @@
# image_specialist.py
# Copyright (C) 2025 Carlos Rodrigues dos Santos
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License...
# PATENT PENDING NOTICE: See NOTICE.md.

from PIL import Image
import os
import time
import logging
import gradio as gr
import yaml

from flux_kontext_helpers import flux_kontext_singleton
from gemini_helpers import gemini_singleton

logger = logging.getLogger(__name__)

class ImageSpecialist:
    """
    ADUC specialist for generating static images (keyframes).
    Responsible for the entire process of turning a storyboard into a gallery of keyframes.
    """
    def __init__(self, workspace_dir):
        self.workspace_dir = workspace_dir
        self.image_generation_helper = flux_kontext_singleton
        logger.info("Especialista de Imagem (Flux) pronto para receber ordens do Maestro.")

    def _generate_single_keyframe(self, prompt: str, reference_images: list[Image.Image], output_filename: str, width: int, height: int, callback: callable = None) -> str:
        """
        Low-level function that generates a single image.
        """
        logger.info(f"Gerando keyframe '{output_filename}' com prompt: '{prompt}'")
        generated_image = self.image_generation_helper.generate_image(
            reference_images=reference_images, prompt=prompt, width=width,
            height=height, seed=int(time.time()), callback=callback
        )
        final_path = os.path.join(self.workspace_dir, output_filename)
        generated_image.save(final_path)
        logger.info(f"Keyframe salvo com sucesso em: {final_path}")
        return final_path

    def generate_keyframes_from_storyboard(self, storyboard: list, initial_ref_path: str, global_prompt: str, keyframe_resolution: int, general_ref_paths: list, progress_callback_factory: callable = None):
        """
        Orchestrates the generation of all keyframes from a storyboard.
        """
        current_base_image_path = initial_ref_path
        previous_prompt = "N/A (imagem inicial de referência)"
        final_keyframes = [current_base_image_path]
        width, height = keyframe_resolution, keyframe_resolution

        # The number of keyframes to generate is len(storyboard) - 1, since the first keyframe already exists (initial_ref_path)
        # and the storyboard has as many entries as the total number of desired keyframes.
        num_keyframes_to_generate = len(storyboard) - 1

        logger.info(f"ESPECIALISTA DE IMAGEM: Recebi ordem para gerar {num_keyframes_to_generate} keyframes.")

        for i in range(num_keyframes_to_generate):
            # The current scene is the transition from storyboard[i] to storyboard[i+1]
            current_scene = storyboard[i]
            future_scene = storyboard[i+1]
            progress_callback = progress_callback_factory(i + 1, num_keyframes_to_generate) if progress_callback_factory else None

            logger.info(f"--> Gerando Keyframe {i+1}/{num_keyframes_to_generate}...")

            # The specialist itself queries Gemini for the image prompt
            new_flux_prompt = gemini_singleton.get_anticipatory_keyframe_prompt(
                global_prompt=global_prompt, scene_history=previous_prompt,
                current_scene_desc=current_scene, future_scene_desc=future_scene,
                last_image_path=current_base_image_path, fixed_ref_paths=general_ref_paths
            )

            images_for_flux_paths = list(set([current_base_image_path] + general_ref_paths))
            images_for_flux = [Image.open(p) for p in images_for_flux_paths]

            new_keyframe_path = self._generate_single_keyframe(
                prompt=new_flux_prompt, reference_images=images_for_flux,
                output_filename=f"keyframe_{i+1}.png", width=width, height=height,
                callback=progress_callback
            )

            final_keyframes.append(new_keyframe_path)
            current_base_image_path = new_keyframe_path
            previous_prompt = new_flux_prompt

        logger.info(f"ESPECIALISTA DE IMAGEM: Geração de keyframes concluída.")
        return final_keyframes

# Singleton instantiation - uses the workspace_dir from the config
try:
    with open("config.yaml", 'r') as f:
        config = yaml.safe_load(f)
    WORKSPACE_DIR = config['application']['workspace_dir']
    image_specialist_singleton = ImageSpecialist(workspace_dir=WORKSPACE_DIR)
except Exception as e:
    logger.error(f"Não foi possível inicializar o ImageSpecialist: {e}", exc_info=True)
    image_specialist_singleton = None
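For orientation only, a sketch (not part of this commit) of how an orchestrator might call the specialist above. The singleton and the method exist in the file; the storyboard text, the reference path, and the resolution are placeholder values assumed for the example.

# Hypothetical driver code for ImageSpecialist; the storyboard entries and
# "workspace/reference_0.png" are made-up example inputs.
from image_specialist import image_specialist_singleton

storyboard = [
    "A lighthouse at dawn, seen from the sea",
    "The keeper climbs the spiral staircase",
    "The lamp ignites as a storm rolls in",
]

if image_specialist_singleton is not None:
    keyframes = image_specialist_singleton.generate_keyframes_from_storyboard(
        storyboard=storyboard,
        initial_ref_path="workspace/reference_0.png",  # assumed to exist
        global_prompt="A short film about a lighthouse keeper",
        keyframe_resolution=1024,
        general_ref_paths=[],
    )
    # keyframes[0] is the initial reference; keyframes[1:] are the generated images.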
inference.py
ADDED
@@ -0,0 +1,774 @@
import argparse
import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml

import imageio
import json
import numpy as np
import torch
import cv2
from safetensors import safe_open
from PIL import Image
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import (
    ConditioningItem,
    LTXVideoPipeline,
    LTXMultiScalePipeline,
)
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
import ltx_video.pipelines.crf_compressor as crf_compressor

MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257

logger = logging.get_logger("LTX-Video")


def get_total_gpu_memory():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return total_memory
    return 44


def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cuda"


def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str) or a PIL Image object
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        just_crop: If True, only crop the image to the target size without resizing
    """
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError("image_input must be either a file path or a PIL Image object")

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    image = np.array(image)
    image = cv2.GaussianBlur(image, (3, 3), 0)
    frame_tensor = torch.from_numpy(image).float()
    frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
    frame_tensor = frame_tensor.permute(2, 0, 1)
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)


def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:

    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return padded tensor
    # Padding format is (left, right, top, bottom)
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding


def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Remove non-letters and convert to lowercase
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )

    # Split into words
    words = clean_text.split()

    # Build result string keeping track of length
    result = []
    current_length = 0

    for word in words:
        # Add word length plus 1 for underscore (except for first word)
        new_length = current_length + len(word)

        if new_length <= max_len:
            result.append(word)
            current_length += len(word)
        else:
            break

    return "-".join(result)


# Generate output video name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    for i in range(index_range):
        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )


def seed_everething(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


def main():
    parser = argparse.ArgumentParser(
        description="Load models from separate directories and run the pipeline."
    )

    # Directories
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Path to the folder to save output video, if None will save in outputs/ directory.",
    )
    parser.add_argument("--seed", type=int, default="171198")

    # Pipeline parameters
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=1,
        help="Number of images per prompt",
    )
    parser.add_argument(
        "--image_cond_noise_scale",
        type=float,
        default=0.15,
        help="Amount of noise to add to the conditioned image",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=704,
        help="Height of the output video frames. Optional if an input image provided.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1216,
        help="Width of the output video frames. If None will infer from input image.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=121,
        help="Number of frames to generate in the output video",
    )
    parser.add_argument(
        "--frame_rate", type=int, default=30, help="Frame rate for the output video"
    )
    parser.add_argument(
        "--device",
        default=None,
        help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
    )
    parser.add_argument(
        "--pipeline_config",
        type=str,
        default="configs/ltxv-13b-0.9.7-dev.yaml",
        help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
    )

    # Prompts
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt to guide generation",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        help="Negative prompt for undesired features",
    )

    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )

    # video-to-video arguments:
    parser.add_argument(
        "--input_media_path",
        type=str,
        default=None,
        help="Path to the input video (or image) to be modified using the video-to-video pipeline",
    )

    # Conditioning arguments
    parser.add_argument(
        "--conditioning_media_paths",
        type=str,
        nargs="*",
        help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
    )
    parser.add_argument(
        "--conditioning_strengths",
        type=float,
        nargs="*",
        help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
    )
    parser.add_argument(
        "--conditioning_start_frames",
        type=int,
        nargs="*",
        help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
    )

    args = parser.parse_args()
    logger.warning(f"Running generation with arguments: {args}")
    infer(**vars(args))


def create_ltx_video_pipeline(
    ckpt_path: str,
    precision: str,
    text_encoder_model_name_or_path: str,
    sampler: Optional[str] = None,
    device: Optional[str] = None,
    enhance_prompt: bool = False,
    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
) -> LTXVideoPipeline:
    ckpt_path = Path(ckpt_path)
    assert os.path.exists(
        ckpt_path
    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"

    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = Transformer3DModel.from_pretrained(ckpt_path)

    # Use constructor if sampler is specified, otherwise use from_pretrained
    if sampler == "from_checkpoint" or not sampler:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    else:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
        )

    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_model_name_or_path, subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_model_name_or_path, subfolder="tokenizer"
    )

    transformer = transformer.to(device)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)

    if enhance_prompt:
        prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
            torch_dtype="bfloat16",
        )
        prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
        )
    else:
        prompt_enhancer_image_caption_model = None
        prompt_enhancer_image_caption_processor = None
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

    vae = vae.to(torch.bfloat16)
    if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
        transformer = transformer.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
        "allowed_inference_steps": allowed_inference_steps,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to(device)
    return pipeline


def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler


def infer(
    output_path: Optional[str],
    seed: int,
    pipeline_config: str,
    image_cond_noise_scale: float,
    height: Optional[int],
    width: Optional[int],
    num_frames: int,
    frame_rate: int,
    prompt: str,
    negative_prompt: str,
    offload_to_cpu: bool,
    input_media_path: Optional[str] = None,
    conditioning_media_paths: Optional[List[str]] = None,
    conditioning_strengths: Optional[List[float]] = None,
    conditioning_start_frames: Optional[List[int]] = None,
    device: Optional[str] = None,
    **kwargs,
):
    # check if pipeline_config is a file
    if not os.path.isfile(pipeline_config):
        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
    with open(pipeline_config, "r") as f:
        pipeline_config = yaml.safe_load(f)

    models_dir = "MODEL_DIR"

    ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
    if not os.path.isfile(ltxv_model_name_or_path):
        ltxv_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=ltxv_model_name_or_path,
            local_dir=models_dir,
            repo_type="model",
        )
    else:
        ltxv_model_path = ltxv_model_name_or_path

    spatial_upscaler_model_name_or_path = pipeline_config.get(
        "spatial_upscaler_model_path"
    )
    if spatial_upscaler_model_name_or_path and not os.path.isfile(
        spatial_upscaler_model_name_or_path
    ):
        spatial_upscaler_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=spatial_upscaler_model_name_or_path,
            local_dir=models_dir,
            repo_type="model",
        )
    else:
        spatial_upscaler_model_path = spatial_upscaler_model_name_or_path

    if kwargs.get("input_image_path", None):
        logger.warning(
            "Please use conditioning_media_paths instead of input_image_path."
        )
        assert not conditioning_media_paths and not conditioning_start_frames
        conditioning_media_paths = [kwargs["input_image_path"]]
        conditioning_start_frames = [0]

    # Validate conditioning arguments
    if conditioning_media_paths:
        # Use default strengths of 1.0
        if not conditioning_strengths:
            conditioning_strengths = [1.0] * len(conditioning_media_paths)
        if not conditioning_start_frames:
            raise ValueError(
                "If `conditioning_media_paths` is provided, "
                "`conditioning_start_frames` must also be provided"
            )
        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
            conditioning_media_paths
        ) != len(conditioning_start_frames):
            raise ValueError(
                "`conditioning_media_paths`, `conditioning_strengths`, "
                "and `conditioning_start_frames` must have the same length"
            )
        if any(s < 0 or s > 1 for s in conditioning_strengths):
            raise ValueError("All conditioning strengths must be between 0 and 1")
        if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {num_frames-1}"
            )

    seed_everething(seed)
    if offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(output_path)
        if output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]

    prompt_word_count = len(prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )

    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
    sampler = pipeline_config["sampler"]
    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
        "prompt_enhancer_image_caption_model_name_or_path"
    ]
    prompt_enhancer_llm_model_name_or_path = pipeline_config[
        "prompt_enhancer_llm_model_name_or_path"
    ]

    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
        sampler=sampler,
        device=kwargs.get("device", get_device()),
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if input_media_path:
        media_item = load_media_file(
            media_path=input_media_path,
            height=height,
            width=width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=height,
            width=width,
            num_frames=num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    stg_mode = pipeline_config.get("stg_mode", "attention_values")
    del pipeline_config["stg_mode"]
    if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
        skip_layer_strategy = SkipLayerStrategy.AttentionValues
    elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
    elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
        skip_layer_strategy = SkipLayerStrategy.Residual
    elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
        skip_layer_strategy = SkipLayerStrategy.TransformerBlock
    else:
        raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

    # Prepare input for the pipeline
    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    device = device or get_device()
    generator = torch.Generator(device=device).manual_seed(seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=frame_rate,
        **sample,
        media_items=media_item,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=prompt,
                seed=seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=prompt,
                seed=seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")


def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters.

    Args:
        conditioning_media_paths: List of paths to conditioning media (images or videos)
        conditioning_strengths: List of conditioning strengths for each media item
        conditioning_start_frames: List of frame indices where each item should be applied
        height: Height of the output frames
        width: Width of the output frames
        num_frames: Number of frames in the output video
        padding: Padding to apply to the frames
        pipeline: LTXVideoPipeline object used for condition video trimming

    Returns:
        A list of ConditioningItem objects.
    """
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        num_input_frames = orig_num_input_frames = get_media_num_frames(path)
        if hasattr(pipeline, "trim_conditioning_sequence") and callable(
            getattr(pipeline, "trim_conditioning_sequence")
        ):
            num_input_frames = pipeline.trim_conditioning_sequence(
                start_frame, orig_num_input_frames, num_frames
            )
        if num_input_frames < orig_num_input_frames:
            logger.warning(
                f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
            )

        media_tensor = load_media_file(
            media_path=path,
            height=height,
            width=width,
            max_frames=num_input_frames,
            padding=padding,
            just_crop=True,
        )
        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
    return conditioning_items


def get_media_num_frames(media_path: str) -> int:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    num_frames = 1
    if is_video:
        reader = imageio.get_reader(media_path)
        num_frames = reader.count_frames()
        reader.close()
    return num_frames


def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    if is_video:
        reader = imageio.get_reader(media_path)
        num_input_frames = min(reader.count_frames(), max_frames)

        # Read and preprocess the relevant frames from the video file.
        frames = []
        for i in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(i))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
            frames.append(frame_tensor)
        reader.close()

        # Stack frames along the temporal dimension
        media_tensor = torch.cat(frames, dim=2)
    else:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        media_tensor = torch.nn.functional.pad(media_tensor, padding)
    return media_tensor


if __name__ == "__main__":
    main()
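inference.py exposes both a CLI (main) and a programmatic entry point (infer). As a sketch only (not part of this commit), one could call infer() directly; the prompt, output directory, and seed below are example values, and the config path is the default that ships in configs/ in this commit.

# Hypothetical direct call to infer(); argument values are illustrative.
from inference import infer

infer(
    output_path="outputs/demo",
    seed=171198,
    pipeline_config="configs/ltxv-13b-0.9.7-dev.yaml",
    image_cond_noise_scale=0.15,
    height=704,
    width=1216,
    num_frames=121,
    frame_rate=30,
    prompt="A sailboat gliding across a calm bay at sunset",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    offload_to_cpu=False,
)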
ltx_manager_helpers.py
ADDED
@@ -0,0 +1,198 @@
# ltx_manager_helpers.py
# Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
#
# ORIGINAL SOURCE: LTX-Video by Lightricks Ltd. & other open-source projects.
# Licensed under the Apache License, Version 2.0
# https://github.com/Lightricks/LTX-Video
#
# MODIFICATIONS FOR ADUC-SDR_Video:
# This file is part of ADUC-SDR_Video, a derivative work based on LTX-Video.
# It has been modified to manage pools of LTX workers, handle GPU memory,
# and prepare parameters for the ADUC-SDR orchestration framework.
# All modifications are also licensed under the Apache License, Version 2.0.

import torch
import gc
import os
import yaml
import logging
import huggingface_hub
import time
import threading
import json

from optimization import optimize_ltx_worker, can_optimize_fp8
from hardware_manager import hardware_manager
from inference import create_ltx_video_pipeline, calculate_padding
from ltx_video.pipelines.pipeline_ltx_video import LatentConditioningItem
from ltx_video.models.autoencoders.vae_encode import vae_decode

logger = logging.getLogger(__name__)

class LtxWorker:
    def __init__(self, device_id, ltx_config_file):
        self.cpu_device = torch.device('cpu')
        self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
        logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")

        with open(ltx_config_file, "r") as file:
            self.config = yaml.safe_load(file)

        self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")

        models_dir = "downloaded_models_gradio"

        logger.info(f"LTX Worker ({self.device}): Carregando modelo para a CPU...")
        model_path = os.path.join(models_dir, self.config["checkpoint_path"])
        if not os.path.exists(model_path):
            model_path = huggingface_hub.hf_hub_download(
                repo_id="Lightricks/LTX-Video", filename=self.config["checkpoint_path"],
                local_dir=models_dir, local_dir_use_symlinks=False
            )

        self.pipeline = create_ltx_video_pipeline(
            ckpt_path=model_path, precision=self.config["precision"],
            text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
            sampler=self.config["sampler"], device='cpu'
        )
        logger.info(f"LTX Worker ({self.device}): Modelo pronto na CPU. É um modelo destilado? {self.is_distilled}")

        if self.device.type == 'cuda' and can_optimize_fp8():
            logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
            self.pipeline.to(self.device)
            optimize_ltx_worker(self)
            self.pipeline.to(self.cpu_device)
            logger.info(f"LTX Worker ({self.device}): Otimização concluída. Modelo pronto.")
        elif self.device.type == 'cuda':
            logger.info(f"LTX Worker ({self.device}): Otimização FP8 não suportada ou desativada. Usando modelo padrão.")

    def to_gpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
        self.pipeline.to(self.device)

    def to_cpu(self):
        if self.device.type == 'cpu': return
        logger.info(f"LTX Worker: Descarregando pipeline da GPU {self.device}...")
        self.pipeline.to('cpu')
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()

    def generate_video_fragment_internal(self, **kwargs):
        return self.pipeline(**kwargs).images

class LtxPoolManager:
    def __init__(self, device_ids, ltx_config_file):
        logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
        self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
        self.current_worker_index = 0
        self.lock = threading.Lock()
        self.last_cleanup_thread = None

    def _cleanup_worker_thread(self, worker):
        logger.info(f"LTX CLEANUP THREAD: Iniciando limpeza de {worker.device} em background...")
        worker.to_cpu()

    def _prepare_and_log_params(self, worker_to_use, **kwargs):
        target_device = worker_to_use.device
        height, width = kwargs['height'], kwargs['width']

        conditioning_data = kwargs.get('conditioning_items_data', [])
        final_conditioning_items = []

        # --- ADDED LOG: details of the conditioning tensors ---
        conditioning_log_details = []
        for i, item in enumerate(conditioning_data):
            if hasattr(item, 'latent_tensor'):
                item.latent_tensor = item.latent_tensor.to(target_device)
                final_conditioning_items.append(item)
                conditioning_log_details.append(
                    f"  - Item {i}: frame={item.media_frame_number}, strength={item.conditioning_strength:.2f}, shape={list(item.latent_tensor.shape)}"
                )

        first_pass_config = worker_to_use.config.get("first_pass", {})
        padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
        padding_vals = calculate_padding(height, width, padded_h, padded_w)

        pipeline_params = {
            "height": padded_h, "width": padded_w,
            "num_frames": kwargs['video_total_frames'], "frame_rate": kwargs['video_fps'],
            "generator": torch.Generator(device=target_device).manual_seed(int(kwargs.get('seed', time.time())) + kwargs['current_fragment_index']),
            "conditioning_items": final_conditioning_items,
            "is_video": True, "vae_per_channel_normalize": True,
            "decode_timestep": float(kwargs.get('decode_timestep', worker_to_use.config.get("decode_timestep", 0.05))),
            "decode_noise_scale": float(kwargs.get('decode_noise_scale', worker_to_use.config.get("decode_noise_scale", 0.025))),
            "image_cond_noise_scale": float(kwargs.get('image_cond_noise_scale', 0.0)),
            "stochastic_sampling": bool(kwargs.get('stochastic_sampling', worker_to_use.config.get("stochastic_sampling", False))),
            "prompt": kwargs['motion_prompt'],
            "negative_prompt": kwargs.get('negative_prompt', "blurry, distorted, static, bad quality, artifacts"),
            "guidance_scale": float(kwargs.get('guidance_scale', 1.0)),
            "stg_scale": float(kwargs.get('stg_scale', 0.0)),
            "rescaling_scale": float(kwargs.get('rescaling_scale', 1.0)),
        }

        if worker_to_use.is_distilled:
            pipeline_params["timesteps"] = first_pass_config.get("timesteps")
            pipeline_params["num_inference_steps"] = len(pipeline_params["timesteps"]) if "timesteps" in first_pass_config else 8
        else:
            pipeline_params["num_inference_steps"] = int(kwargs.get('num_inference_steps', 7))

        # --- ADDED LOG: full dump of the pipeline parameters ---
        log_friendly_params = pipeline_params.copy()
        log_friendly_params.pop('generator', None)
        log_friendly_params.pop('conditioning_items', None)

        logger.info("="*60)
        logger.info(f"CHAMADA AO PIPELINE LTX NO DISPOSITIVO: {worker_to_use.device}")
        logger.info(f"Modelo: {'Distilled' if worker_to_use.is_distilled else 'Base'}")
        logger.info("-" * 20 + " PARÂMETROS DA PIPELINE " + "-" * 20)
        logger.info(json.dumps(log_friendly_params, indent=2))
        logger.info("-" * 20 + " ITENS DE CONDICIONAMENTO " + "-" * 19)
        logger.info("\n".join(conditioning_log_details))
        logger.info("="*60)
        # --- END OF ADDED LOG ---

        return pipeline_params, padding_vals

    def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
        worker_to_use = None
        progress = kwargs.get('progress')
        try:
            with self.lock:
                if self.last_cleanup_thread and self.last_cleanup_thread.is_alive():
                    self.last_cleanup_thread.join()
                worker_to_use = self.workers[self.current_worker_index]
                previous_worker_index = (self.current_worker_index - 1 + len(self.workers)) % len(self.workers)
                worker_to_cleanup = self.workers[previous_worker_index]
                cleanup_thread = threading.Thread(target=self._cleanup_worker_thread, args=(worker_to_cleanup,))
                cleanup_thread.start()
                self.last_cleanup_thread = cleanup_thread
                worker_to_use.to_gpu()
                self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)

            pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
            pipeline_params['output_type'] = "latent"

            if progress: progress(0.1, desc=f"[Especialista LTX em {worker_to_use.device}] Gerando latentes...")

            with torch.no_grad():
                result_tensor = worker_to_use.generate_video_fragment_internal(**pipeline_params)

            return result_tensor, padding_vals
        except Exception as e:
            logger.error(f"LTX POOL MANAGER: Erro durante a geração de latentes: {e}", exc_info=True)
            raise e
        finally:
            if worker_to_use:
                logger.info(f"LTX POOL MANAGER: Executando limpeza final para {worker_to_use.device}...")
                worker_to_use.to_cpu()


logger.info("Lendo config.yaml para inicializar o LTX Pool Manager...")
with open("config.yaml", 'r') as f:
    config = yaml.safe_load(f)
ltx_gpus_required = config['specialists']['ltx']['gpus_required']
ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
ltx_config_path = config['specialists']['ltx']['config_file']
ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file=ltx_config_path)
logger.info("Especialista de Vídeo (LTX) pronto.")
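For orientation only, a sketch (not part of this commit) of requesting one latent fragment from the pool manager above. The keyword names are the ones _prepare_and_log_params reads; the values and the empty conditioning list are placeholders assumed for the example.

# Hypothetical call to the pool manager; values are illustrative placeholders.
from ltx_manager_helpers import ltx_manager_singleton

latents, padding = ltx_manager_singleton.generate_latent_fragment(
    height=512,
    width=768,
    video_total_frames=97,
    video_fps=24,
    current_fragment_index=0,
    motion_prompt="slow dolly-in toward the lighthouse",
    conditioning_items_data=[],   # would normally hold LatentConditioningItem objects
    guidance_scale=1.0,
    stg_scale=0.0,
    num_inference_steps=7,
)
# latents: raw latent tensor in "latent" output mode; padding: the (left, right, top, bottom) values used.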
ltx_video/LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
ltx_video/README.md ADDED
@@ -0,0 +1,135 @@
# 🛠️ helpers/ - Third-Party AI Tools Adapted for ADUC-SDR

This folder contains adapted implementations of third-party AI models and utilities, which serve as low-level "specialists" or "tools" for the ADUC-SDR architecture.

**IMPORTANT:** The content of this folder is authored by its respective original creators and developers. This folder is **NOT PART** of the main ADUC-SDR project in terms of its novel architecture. It serves as a repository for the **direct, modified dependencies** that the `DeformesXDEngines` (the stages of the ADUC-SDR "rocket") invoke to perform specific tasks (image, video, and audio generation).

The modifications made to the files here aim primarily at:
1. **Interface Adaptation:** Standardizing the interfaces so that they fit into the ADUC-SDR orchestration flow.
2. **Resource Management:** Integrating model load/unload logic (GPU management) and configuration via YAML files.
3. **Pipeline Optimization:** Adjusting the pipelines to accept more efficient input formats (e.g., pre-encoded tensors instead of media paths, skipping redundant encode/decode steps).

---

## 📄 Licensing

The original content of the projects listed below is licensed under the **Apache License 2.0**, or another license specified by the original authors. All modifications and the use of these files within the `helpers/` structure of the ADUC-SDR project comply with the terms of the **Apache License 2.0**.

The original licenses of the projects can be found at their respective sources or in the `incl_licenses/` subdirectories inside each adapted module.

---

## 🛠️ Helper API and Usage Guide

This section details how each helper (specialist agent) should be used within the ADUC-SDR ecosystem. All agents are instantiated as **singletons** in `hardware_manager.py` to guarantee centralized management of GPU resources.

### **gemini_helpers.py (GeminiAgent)**

* **Purpose:** Acts as the "Adaptive Synthesis Oracle", responsible for all natural-language processing tasks such as storyboard creation, prompt generation, and narrative decision-making. A minimal usage sketch follows this section.
* **Singleton Instance:** `gemini_agent_singleton`
* **Constructor:** `GeminiAgent()`
    * Reads `configs/gemini_config.yaml` for the model name, inference parameters, and prompt-template paths. The API key is read from the `GEMINI_API_KEY` environment variable.
* **Public Methods:**
    * `generate_storyboard(prompt: str, num_keyframes: int, ref_image_paths: list[str])`
        * **Inputs:**
            * `prompt`: The overall idea of the film (string).
            * `num_keyframes`: The number of scenes to generate (int).
            * `ref_image_paths`: List of paths to the reference images (list[str]).
        * **Output:** `tuple[list[str], str]` (a tuple containing the list of storyboard strings and a textual report of the operation).
    * `select_keyframes_from_pool(storyboard: list, base_image_paths: list[str], pool_image_paths: list[str])`
        * **Inputs:**
            * `storyboard`: The list of generated storyboard strings.
            * `base_image_paths`: Base reference images (list[str]).
            * `pool_image_paths`: The "image bank" to select from (list[str]).
        * **Output:** `tuple[list[str], str]` (a tuple containing the list of selected image paths and a textual report).
    * `get_anticipatory_keyframe_prompt(...)`
        * **Inputs:** Narrative and visual context used to generate an image prompt.
        * **Output:** `tuple[str, str]` (a tuple containing the generated prompt for the image model and a textual report).
    * `get_initial_motion_prompt(...)`
        * **Inputs:** Narrative and visual context for the first video transition.
        * **Output:** `tuple[str, str]` (a tuple containing the generated motion prompt and a textual report).
    * `get_transition_decision(...)`
        * **Inputs:** Narrative and visual context for an intermediate video transition.
        * **Output:** `tuple[dict, str]` (a tuple containing a `{"transition_type": "...", "motion_prompt": "..."}` dictionary and a textual report).
    * `generate_audio_prompts(...)`
        * **Inputs:** Global narrative context.
        * **Output:** `tuple[dict, str]` (a tuple containing a `{"music_prompt": "...", "sfx_prompt": "..."}` dictionary and a textual report).

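A minimal usage sketch for `generate_storyboard`, assuming `gemini_helpers` is importable from the application root and `GEMINI_API_KEY` is set; the prompt and file names below are placeholders:

```python
# Illustrative sketch only; not part of the committed code.
from gemini_helpers import gemini_agent_singleton

storyboard, report = gemini_agent_singleton.generate_storyboard(
    prompt="A lighthouse keeper discovers a bioluminescent creature",
    num_keyframes=4,
    ref_image_paths=["refs/keeper.png", "refs/lighthouse.png"],
)
print(report)
for scene in storyboard:
    print("-", scene)
```
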
### **flux_kontext_helpers.py (FluxPoolManager)**

* **Purpose:** Specialist in high-quality image (keyframe) generation using the FluxKontext pipeline. Manages a pool of workers to make the best use of multiple GPUs. See the sketch after this list.
* **Singleton Instance:** `flux_kontext_singleton`
* **Constructor:** `FluxPoolManager(device_ids: list[str], flux_config_file: str)`
    * Reads `configs/flux_config.yaml`.
* **Public Method:**
    * `generate_image(prompt: str, reference_images: list[Image.Image], width: int, height: int, seed: int = 42, callback: callable = None)`
        * **Inputs:**
            * `prompt`: Text prompt guiding the generation (string).
            * `reference_images`: List of `PIL.Image` objects used as visual reference.
            * `width`, `height`: Output image dimensions (int).
            * `seed`: Seed for reproducibility (int).
            * `callback`: Optional callback function for progress monitoring.
        * **Output:** `PIL.Image.Image` (the generated image object).

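A minimal usage sketch, assuming `flux_kontext_singleton` can be imported from `flux_kontext_helpers` as documented above; paths, prompt, and dimensions are placeholders:

```python
# Illustrative sketch only; not part of the committed code.
from PIL import Image
from flux_kontext_helpers import flux_kontext_singleton

refs = [Image.open("refs/keeper.png"), Image.open("refs/lighthouse.png")]
keyframe = flux_kontext_singleton.generate_image(
    prompt="The keeper raises a lantern over the tide pools, cinematic lighting",
    reference_images=refs,
    width=1024,
    height=576,
    seed=42,
)
keyframe.save("keyframes/scene_01.png")
```
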
### **dreamo_helpers.py (DreamOAgent)**

* **Purpose:** Specialist in high-quality image (keyframe) generation using the DreamO pipeline, with advanced editing and style capabilities driven by reference images.
* **Singleton Instance:** `dreamo_agent_singleton`
* **Constructor:** `DreamOAgent(device_id: str = None)`
    * Reads `configs/dreamo_config.yaml`.
* **Public Method:**
    * `generate_image(prompt: str, reference_images: list[Image.Image], width: int, height: int)`
        * **Inputs:**
            * `prompt`: Text prompt guiding the generation (string).
            * `reference_images`: List of `PIL.Image` objects used as visual reference. The internal logic assigns the first image as `style` and the remaining ones as `ip`.
            * `width`, `height`: Output image dimensions (int).
        * **Output:** `PIL.Image.Image` (the generated image object).

### **ltx_manager_helpers.py (LtxPoolManager)**

* **Purpose:** Specialist in generating video fragments in latent space using the LTX-Video pipeline. Manages a pool of workers to make the best use of multiple GPUs. See the sketch after this list.
* **Singleton Instance:** `ltx_manager_singleton`
* **Constructor:** `LtxPoolManager(device_ids: list[str], ltx_model_config_file: str, ltx_global_config_file: str)`
    * Reads `ltx_global_config_file` and `ltx_model_config_file` to configure the pipeline.
* **Public Method:**
    * `generate_latent_fragment(**kwargs)`
        * **Inputs:** Dictionary of keyword arguments (`kwargs`) containing all LTX pipeline parameters, including:
            * `height`, `width`: Video dimensions (int).
            * `video_total_frames`: Total number of frames to generate (int).
            * `video_fps`: Frames per second (int).
            * `motion_prompt`: Motion prompt (string).
            * `conditioning_items_data`: List of `LatentConditioningItem` objects containing the conditioning latent tensors.
            * `guidance_scale`, `stg_scale`, `num_inference_steps`, etc.
        * **Output:** `tuple[torch.Tensor, tuple]` (a tuple containing the generated latent tensor and the padding values used).

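A minimal usage sketch, assuming `ltx_manager_singleton` is importable from `ltx_manager_helpers`; the numeric values are placeholders and `conditioning_items` is a stub for a list of `LatentConditioningItem` built elsewhere by the engine:

```python
# Illustrative sketch only; not part of the committed code.
from ltx_manager_helpers import ltx_manager_singleton

latents, padding_vals = ltx_manager_singleton.generate_latent_fragment(
    height=576,
    width=1024,
    video_total_frames=121,
    video_fps=24,
    motion_prompt="slow dolly-in towards the lighthouse door",
    conditioning_items_data=conditioning_items,  # list[LatentConditioningItem], built upstream
    guidance_scale=3.0,
    num_inference_steps=20,
)
print(latents.shape, padding_vals)
```
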
### **mmaudio_helper.py (MMAudioAgent)**

* **Purpose:** Specialist in generating audio for a given video fragment. See the sketch after this list.
* **Singleton Instance:** `mmaudio_agent_singleton`
* **Constructor:** `MMAudioAgent(workspace_dir: str, device_id: str = None, mmaudio_config_file: str)`
    * Reads `configs/mmaudio_config.yaml`.
* **Public Method:**
    * `generate_audio_for_video(video_path: str, prompt: str, negative_prompt: str, duration_seconds: float)`
        * **Inputs:**
            * `video_path`: Path to the silent video file (string).
            * `prompt`: Text prompt guiding the audio generation (string).
            * `negative_prompt`: Negative prompt for the audio (string).
            * `duration_seconds`: Exact duration of the video (float).
        * **Output:** `str` (the path to the new video file with the audio track muxed in).

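A minimal usage sketch, assuming `mmaudio_agent_singleton` is importable from `mmaudio_helper`; paths and prompts are placeholders:

```python
# Illustrative sketch only; not part of the committed code.
from mmaudio_helper import mmaudio_agent_singleton

video_with_audio = mmaudio_agent_singleton.generate_audio_for_video(
    video_path="fragments/scene_01_silent.mp4",
    prompt="gentle waves, distant seagulls, creaking wood",
    negative_prompt="music, speech",
    duration_seconds=5.0,
)
print("Muxed video written to:", video_with_audio)
```
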
---

## 🔗 Original Projects and Attributions
(The attributions and licenses section remains the same as previously defined.)

### DreamO
* **Original Repository:** [https://github.com/bytedance/DreamO](https://github.com/bytedance/DreamO)
...

### LTX-Video
* **Original Repository:** [https://github.com/Lightricks/LTX-Video](https://github.com/Lightricks/LTX-Video)
...

### MMAudio
* **Original Repository:** [https://github.com/hkchengrex/MMAudio](https://github.com/hkchengrex/MMAudio)
...
ltx_video/__init__.py ADDED
File without changes

ltx_video/models/__init__.py ADDED
File without changes

ltx_video/models/autoencoders/__init__.py ADDED
File without changes

ltx_video/models/autoencoders/causal_conv3d.py ADDED
@@ -0,0 +1,63 @@
from typing import Tuple, Union

import torch
import torch.nn as nn


class CausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        return self.conv.weight
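Editorial note (not part of the commit): a small self-contained check of what the causal padding above does. With `causal=True`, the first frame is replicated `kernel_size - 1` times in front of the clip before the temporal convolution, so at stride 1 the frame count is preserved and no future frame influences an earlier output. The snippet assumes the `ltx_video` package is importable:

```python
# Sketch only: verifies the shape-preserving behavior of CausalConv3d.
import torch
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d

conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=3, stride=1)
video = torch.randn(1, 3, 9, 64, 64)  # (B, C, F, H, W)

out = conv(video, causal=True)
# Temporal dim unchanged: 2 replicated leading frames compensate for the
# un-padded temporal axis of the underlying nn.Conv3d.
print(out.shape)  # torch.Size([1, 8, 9, 64, 64])
```
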
ltx_video/models/autoencoders/causal_video_autoencoder.py ADDED
@@ -0,0 +1,1398 @@
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union, List
from pathlib import Path

import torch
import numpy as np
from einops import rearrange
from torch import nn
from diffusers.utils import logging
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from safetensors import safe_open


from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
from ltx_video.models.transformers.attention import Attention
from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    VAE_KEYS_RENAME_DICT,
)

PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CausalVideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if (
            pretrained_model_name_or_path.is_dir()
            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
        ):
            config_local_path = pretrained_model_name_or_path / "config.json"
            config = cls.load_config(config_local_path, **kwargs)

            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))

            statistics_local_path = (
                pretrained_model_name_or_path / "per_channel_statistics.json"
            )
            if statistics_local_path.exists():
                with open(statistics_local_path, "r") as file:
                    data = json.load(file)
                transposed_data = list(zip(*data["data"]))
                data_dict = {
                    col: torch.tensor(vals)
                    for col, vals in zip(data["columns"], transposed_data)
                }
                std_of_means = data_dict["std-of-means"]
                mean_of_means = data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
                    std_of_means
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
                    mean_of_means
                )

        elif pretrained_model_name_or_path.is_dir():
            config_path = pretrained_model_name_or_path / "vae" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for VAE is not suppported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]

            state_dict_path = (
                pretrained_model_name_or_path
                / "vae"
                / "diffusion_pytorch_model.safetensors"
            )

            state_dict = {}
            with safe_open(state_dict_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)

                state_dict[new_key] = state_dict.pop(key)

        elif pretrained_model_name_or_path.is_file() and str(
            pretrained_model_name_or_path
        ).endswith(".safetensors"):
            state_dict = {}
            with safe_open(
                pretrained_model_name_or_path, framework="pt", device="cpu"
            ) as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            config = configs["vae"]

        video_vae = cls.from_config(config)
        if "torch_dtype" in kwargs:
            video_vae.to(kwargs["torch_dtype"])
        video_vae.load_state_dict(state_dict)
        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "CausalVideoAutoencoder"
        ), "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)
        normalize_latent_channels = config.get("normalize_latent_channels", False)

        if use_quant_conv and latent_log_var in ["uniform", "constant"]:
            raise ValueError(
                f"latent_log_var={latent_log_var} requires use_quant_conv=False"
            )

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            base_channels=config.get("encoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.encoder.blocks_desc
                    if block[0]
                    in [
                        "compress_space",
                        "compress_all",
                        "compress_all_res",
                        "compress_space_res",
                    ]
                ]
            )
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        return 2 ** len(
            [
                block
                for block in self.encoder.blocks_desc
                if block[0]
                in [
                    "compress_time",
                    "compress_all",
                    "compress_all_res",
                    "compress_time_res",
                ]
            ]
        )

    def to_json_string(self) -> str:
        import json

        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if any([key.startswith("vae.") for key in state_dict.keys()]):
            state_dict = {
                key.replace("vae.", ""): value
                for key, value in state_dict.items()
                if key.startswith("vae.")
            }
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }

        model_keys = set(name for name, _ in self.named_modules())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in ckpt_state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            key_prefix = ".".join(key.split(".")[:-1])
            if "norm" in key and key_prefix not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

        data_dict = {
            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
            for key, value in state_dict.items()
            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }
        if len(data_dict) > 0:
            self.register_buffer("std_of_means", data_dict["std-of-means"])
            self.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var == "constant":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            conv_out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")
        elif self.latent_log_var == "constant":
            sample = sample[:, :-1, ...]
            approx_ln_0 = (
                -30
            )  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )

        return sample


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to use causal convolutions or not.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.blocks_desc = blocks

        # Compute output channel to be product of all channel-multiplier blocks
        output_channel = base_channels
        for block_name, block_params in list(reversed(blocks)):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)
            if block_name.startswith("compress"):
                output_channel = output_channel * block_params.get("multiplier", 1)

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in list(reversed(blocks)):
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "attn_res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    attention_head_dim=block_params["attention_head_dim"],
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=False,
                    spatial_padding_mode=spatial_padding_mode,
| 667 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 668 |
+
)
|
| 669 |
+
elif block_name == "compress_time":
|
| 670 |
+
block = DepthToSpaceUpsample(
|
| 671 |
+
dims=dims,
|
| 672 |
+
in_channels=input_channel,
|
| 673 |
+
stride=(2, 1, 1),
|
| 674 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 675 |
+
)
|
| 676 |
+
elif block_name == "compress_space":
|
| 677 |
+
block = DepthToSpaceUpsample(
|
| 678 |
+
dims=dims,
|
| 679 |
+
in_channels=input_channel,
|
| 680 |
+
stride=(1, 2, 2),
|
| 681 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 682 |
+
)
|
| 683 |
+
elif block_name == "compress_all":
|
| 684 |
+
output_channel = output_channel // block_params.get("multiplier", 1)
|
| 685 |
+
block = DepthToSpaceUpsample(
|
| 686 |
+
dims=dims,
|
| 687 |
+
in_channels=input_channel,
|
| 688 |
+
stride=(2, 2, 2),
|
| 689 |
+
residual=block_params.get("residual", False),
|
| 690 |
+
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
| 691 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 692 |
+
)
|
| 693 |
+
else:
|
| 694 |
+
raise ValueError(f"unknown layer: {block_name}")
|
| 695 |
+
|
| 696 |
+
self.up_blocks.append(block)
|
| 697 |
+
|
| 698 |
+
if norm_layer == "group_norm":
|
| 699 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 700 |
+
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
| 701 |
+
)
|
| 702 |
+
elif norm_layer == "pixel_norm":
|
| 703 |
+
self.conv_norm_out = PixelNorm()
|
| 704 |
+
elif norm_layer == "layer_norm":
|
| 705 |
+
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
| 706 |
+
|
| 707 |
+
self.conv_act = nn.SiLU()
|
| 708 |
+
self.conv_out = make_conv_nd(
|
| 709 |
+
dims,
|
| 710 |
+
output_channel,
|
| 711 |
+
out_channels,
|
| 712 |
+
3,
|
| 713 |
+
padding=1,
|
| 714 |
+
causal=True,
|
| 715 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
self.gradient_checkpointing = False
|
| 719 |
+
|
| 720 |
+
self.timestep_conditioning = timestep_conditioning
|
| 721 |
+
|
| 722 |
+
if timestep_conditioning:
|
| 723 |
+
self.timestep_scale_multiplier = nn.Parameter(
|
| 724 |
+
torch.tensor(1000.0, dtype=torch.float32)
|
| 725 |
+
)
|
| 726 |
+
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
| 727 |
+
output_channel * 2, 0
|
| 728 |
+
)
|
| 729 |
+
self.last_scale_shift_table = nn.Parameter(
|
| 730 |
+
torch.randn(2, output_channel) / output_channel**0.5
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
def forward(
|
| 734 |
+
self,
|
| 735 |
+
sample: torch.FloatTensor,
|
| 736 |
+
target_shape,
|
| 737 |
+
timestep: Optional[torch.Tensor] = None,
|
| 738 |
+
) -> torch.FloatTensor:
|
| 739 |
+
r"""The forward method of the `Decoder` class."""
|
| 740 |
+
assert target_shape is not None, "target_shape must be provided"
|
| 741 |
+
batch_size = sample.shape[0]
|
| 742 |
+
|
| 743 |
+
sample = self.conv_in(sample, causal=self.causal)
|
| 744 |
+
|
| 745 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
| 746 |
+
|
| 747 |
+
checkpoint_fn = (
|
| 748 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 749 |
+
if self.gradient_checkpointing and self.training
|
| 750 |
+
else lambda x: x
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
sample = sample.to(upscale_dtype)
|
| 754 |
+
|
| 755 |
+
if self.timestep_conditioning:
|
| 756 |
+
assert (
|
| 757 |
+
timestep is not None
|
| 758 |
+
), "should pass timestep with timestep_conditioning=True"
|
| 759 |
+
scaled_timestep = timestep * self.timestep_scale_multiplier
|
| 760 |
+
|
| 761 |
+
for up_block in self.up_blocks:
|
| 762 |
+
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
| 763 |
+
sample = checkpoint_fn(up_block)(
|
| 764 |
+
sample, causal=self.causal, timestep=scaled_timestep
|
| 765 |
+
)
|
| 766 |
+
else:
|
| 767 |
+
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
| 768 |
+
|
| 769 |
+
sample = self.conv_norm_out(sample)
|
| 770 |
+
|
| 771 |
+
if self.timestep_conditioning:
|
| 772 |
+
embedded_timestep = self.last_time_embedder(
|
| 773 |
+
timestep=scaled_timestep.flatten(),
|
| 774 |
+
resolution=None,
|
| 775 |
+
aspect_ratio=None,
|
| 776 |
+
batch_size=sample.shape[0],
|
| 777 |
+
hidden_dtype=sample.dtype,
|
| 778 |
+
)
|
| 779 |
+
embedded_timestep = embedded_timestep.view(
|
| 780 |
+
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
| 781 |
+
)
|
| 782 |
+
ada_values = self.last_scale_shift_table[
|
| 783 |
+
None, ..., None, None, None
|
| 784 |
+
] + embedded_timestep.reshape(
|
| 785 |
+
batch_size,
|
| 786 |
+
2,
|
| 787 |
+
-1,
|
| 788 |
+
embedded_timestep.shape[-3],
|
| 789 |
+
embedded_timestep.shape[-2],
|
| 790 |
+
embedded_timestep.shape[-1],
|
| 791 |
+
)
|
| 792 |
+
shift, scale = ada_values.unbind(dim=1)
|
| 793 |
+
sample = sample * (1 + scale) + shift
|
| 794 |
+
|
| 795 |
+
sample = self.conv_act(sample)
|
| 796 |
+
sample = self.conv_out(sample, causal=self.causal)
|
| 797 |
+
|
| 798 |
+
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
| 799 |
+
|
| 800 |
+
return sample
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
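# --- Illustrative usage sketch (editor's note; not part of the committed file) ---
# A minimal call of the Decoder above with timestep conditioning enabled; the block
# layout, channel sizes, and timestep value are illustrative only.
def _example_decoder_call():
    import torch

    decoder = Decoder(
        dims=3,
        in_channels=64,
        out_channels=3,
        blocks=[("res_x", {"num_layers": 1})],
        base_channels=64,
        timestep_conditioning=True,
    )
    latent = torch.randn(1, 64, 2, 8, 8)     # (B, C, F, H, W) latent
    timestep = torch.ones(1) * 0.05          # per-sample decode timestep
    # target_shape is only asserted to be non-None in this forward pass.
    return decoder(latent, target_shape=(1, 3, 2, 8, 8), timestep=timestep)
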
class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
        attention_head_dim (`int`, *optional*, defaults to -1):
            The dimension of the attention head. If -1, no attention is used.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        attention_head_dim: int = -1,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )
        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                in_channels * 4, 0
            )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                    inject_noise=inject_noise,
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
                for _ in range(num_layers)
            ]
        )

        self.attention_blocks = None

        if attention_head_dim > 0:
            if attention_head_dim > in_channels:
                raise ValueError(
                    "attention_head_dim must be less than or equal to in_channels"
                )

            self.attention_blocks = nn.ModuleList(
                [
                    Attention(
                        query_dim=in_channels,
                        heads=in_channels // attention_head_dim,
                        dim_head=attention_head_dim,
                        bias=True,
                        out_bias=True,
                        qk_norm="rms_norm",
                        residual_connection=True,
                    )
                    for _ in range(num_layers)
                ]
            )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        timestep_embed = None
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            batch_size = hidden_states.shape[0]
            timestep_embed = self.time_embedder(
                timestep=timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )
            timestep_embed = timestep_embed.view(
                batch_size, timestep_embed.shape[-1], 1, 1, 1
            )

        if self.attention_blocks:
            for resnet, attention in zip(self.res_blocks, self.attention_blocks):
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

                # Reshape the hidden states to be (batch_size, frames * height * width, channel)
                batch_size, channel, frames, height, width = hidden_states.shape
                hidden_states = hidden_states.view(
                    batch_size, channel, frames * height * width
                ).transpose(1, 2)

                if attention.use_tpu_flash_attention:
                    # Pad the second dimension to be divisible by block_k_major (block in flash attention)
                    seq_len = hidden_states.shape[1]
                    block_k_major = 512
                    pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
                    if pad_len > 0:
                        hidden_states = F.pad(
                            hidden_states, (0, 0, 0, pad_len), "constant", 0
                        )

                    # Create a mask with ones for the original sequence length and zeros for the padded indexes
                    mask = torch.ones(
                        (hidden_states.shape[0], seq_len),
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                    if pad_len > 0:
                        mask = F.pad(mask, (0, pad_len), "constant", 0)

                hidden_states = attention(
                    hidden_states,
                    attention_mask=(
                        None if not attention.use_tpu_flash_attention else mask
                    ),
                )

                if attention.use_tpu_flash_attention:
                    # Remove the padding
                    if pad_len > 0:
                        hidden_states = hidden_states[:, :-pad_len, :]

                # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
                hidden_states = hidden_states.transpose(-1, -2).reshape(
                    batch_size, channel, frames, height, width
                )
        else:
            for resnet in self.res_blocks:
                hidden_states = resnet(
                    hidden_states, causal=causal, timestep=timestep_embed
                )

        return hidden_states

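# --- Illustrative usage sketch (editor's note; not part of the committed file) ---
# When attention is enabled, the number of heads is derived from the channel count as
# heads = in_channels // attention_head_dim. The sizes below are illustrative only.
def _example_mid_block():
    import torch

    block = UNetMidBlock3D(dims=3, in_channels=64, num_layers=1, attention_head_dim=32)
    x = torch.randn(1, 64, 2, 8, 8)          # (B, C, F, H, W)
    return block(x, causal=True)             # same shape out: (1, 64, 2, 8, 8)
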
class SpaceToDepthDownsample(nn.Module):
    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
        super().__init__()
        self.stride = stride
        self.group_size = in_channels * np.prod(stride) // out_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // np.prod(stride),
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, x, causal: bool = True):
        if self.stride[0] == 2:
            x = torch.cat(
                [x[:, :, :1, :, :], x], dim=2
            )  # duplicate first frames for padding

        # skip connection
        x_in = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
        x_in = x_in.mean(dim=2)

        # conv
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

        x = x + x_in

        return x

class DepthToSpaceUpsample(nn.Module):
    def __init__(
        self,
        dims,
        in_channels,
        stride,
        residual=False,
        out_channels_reduction_factor=1,
        spatial_padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        self.out_channels = (
            np.prod(stride) * in_channels // out_channels_reduction_factor
        )
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(self, x, causal: bool = True):
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = self.pixel_shuffle(x)
            num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
            if self.stride[0] == 2:
                x_in = x_in[:, :, 1:, :, :]
        x = self.conv(x, causal=causal)
        x = self.pixel_shuffle(x)
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x

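# --- Illustrative shape check (editor's note; not part of the committed file) ---
# The upsampler expands channels by prod(stride) // out_channels_reduction_factor in the
# conv, then PixelShuffleND trades those channels back for spatial/temporal resolution.
# The sizes below are illustrative only.
def _example_depth_to_space_upsample():
    import torch

    up = DepthToSpaceUpsample(dims=3, in_channels=64, stride=(2, 2, 2), residual=True)
    x = torch.randn(1, 64, 3, 8, 8)          # (B, C, F, H, W)
    y = up(x, causal=True)
    # conv: 64 -> 512 channels, shuffle: back to 64 channels with doubled F/H/W,
    # then the first (duplicated) frame is dropped because stride[0] == 2.
    assert y.shape == (1, 64, 5, 16, 16)
    return y
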
class LayerNorm(nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        x = rearrange(x, "b c d h w -> b d h w c")
        x = self.norm(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x

class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        if norm_layer == "group_norm":
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        if norm_layer == "group_norm":
            self.norm2 = nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # similar to the "explicit noise inputs" method in style-gan
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            ada_values = self.scale_shift_table[
                None, ..., None, None, None
            ] + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale1
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale2
            )

        input_tensor = self.norm3(input_tensor)

        batch_size = input_tensor.shape[0]

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor

def patchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    return x

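# --- Illustrative shape check (editor's note; not part of the committed file) ---
# patchify folds spatial (and optionally temporal) patches into channels, which is how
# the Encoder consumes pixels before its first convolution. The sizes below are
# illustrative only.
def _example_patchify_shapes():
    import torch

    video = torch.randn(1, 3, 8, 64, 64)                # (B, C, F, H, W)
    patched = patchify(video, patch_size_hw=4, patch_size_t=1)
    assert patched.shape == (1, 3 * 4 * 4, 8, 16, 16)   # channels x16, H and W / 4
    restored = unpatchify(patched, patch_size_hw=4, patch_size_t=1)
    assert restored.shape == video.shape
    return restored
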
def create_video_autoencoder_demo_config(
    latent_channels: int = 64,
):
    encoder_blocks = [
        ("res_x", {"num_layers": 2}),
        ("compress_space_res", {"multiplier": 2}),
        ("compress_time_res", {"multiplier": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 1}),
    ]
    decoder_blocks = [
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 2, "inject_noise": False}),
    ]
    return {
        "_class_name": "CausalVideoAutoencoder",
        "dims": 3,
        "encoder_blocks": encoder_blocks,
        "decoder_blocks": decoder_blocks,
        "latent_channels": latent_channels,
        "norm_layer": "pixel_norm",
        "patch_size": 4,
        "latent_log_var": "uniform",
        "use_quant_conv": False,
        "causal_decoder": False,
        "timestep_conditioning": True,
        "spatial_padding_mode": "replicate",
    }


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_demo_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)

    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 17, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")

    timestep = torch.ones(input_videos.shape[0]) * 0.1
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape, timestep=timestep
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Validate that a single image gets treated the same way as the first frame
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    _ = video_autoencoder.decode(
        image_latent, target_shape=image_latent.shape, timestep=timestep
    ).sample

    first_frame_latent = latent[:, :, :1, :, :]

    assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
    # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Call the demo function to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()

ltx_video/models/autoencoders/conv_nd_factory.py
ADDED
@@ -0,0 +1,90 @@
from typing import Tuple, Union

import torch

from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
    spatial_padding_mode="zeros",
    temporal_padding_mode="zeros",
):
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


def make_linear_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    bias=True,
):
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    elif dims == 3 or dims == (2, 1):
        return torch.nn.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")

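# --- Illustrative usage sketch (editor's note; not part of the committed file) ---
# The factory switches between plain 2D/3D convolutions, the causal 3D variant, and the
# factorized DualConv3d depending on `dims` and `causal`. Values below are illustrative.
def _example_make_conv_nd():
    causal_conv = make_conv_nd(
        dims=3, in_channels=16, out_channels=32, kernel_size=3, padding=1, causal=True
    )  # -> CausalConv3d
    dual_conv = make_conv_nd(
        dims=(2, 1), in_channels=16, out_channels=32, kernel_size=3, padding=1
    )  # -> DualConv3d (2D spatial conv followed by 1D temporal conv)
    return causal_conv, dual_conv
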
ltx_video/models/autoencoders/dual_conv3d.py
ADDED
@@ -0,0 +1,217 @@
import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DualConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(DualConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if kernel_size == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)

        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias

        # Define the size of the channels after the first convolution
        intermediate_channels = (
            out_channels if in_channels < out_channels else in_channels
        )

        # Define parameters for the first convolution
        self.weight1 = nn.Parameter(
            torch.Tensor(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
        else:
            self.register_parameter("bias1", None)

        # Define parameters for the second convolution
        self.weight2 = nn.Parameter(
            torch.Tensor(
                out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
            )
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias2", None)

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        else:
            return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x, skip_time_conv):
        # First convolution
        x = F.conv3d(
            x,
            self.weight1,
            self.bias1,
            self.stride1,
            self.padding1,
            self.dilation1,
            self.groups,
            padding_mode=self.padding_mode,
        )

        if skip_time_conv:
            return x

        # Second convolution
        x = F.conv3d(
            x,
            self.weight2,
            self.bias2,
            self.stride2,
            self.padding2,
            self.dilation2,
            self.groups,
            padding_mode=self.padding_mode,
        )

        return x

    def forward_with_2d(self, x, skip_time_conv):
        b, c, d, h, w = x.shape

        # First 2D convolution
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # Squeeze the depth dimension out of weight1 since it's 1
        weight1 = self.weight1.squeeze(2)
        # Select stride, padding, and dilation for the 2D convolution
        stride1 = (self.stride1[1], self.stride1[2])
        padding1 = (self.padding1[1], self.padding1[2])
        dilation1 = (self.dilation1[1], self.dilation1[2])
        x = F.conv2d(
            x,
            weight1,
            self.bias1,
            stride1,
            padding1,
            dilation1,
            self.groups,
            padding_mode=self.padding_mode,
        )

        _, _, h, w = x.shape

        if skip_time_conv:
            x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
            return x

        # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)

        # Reshape weight2 to match the expected dimensions for conv1d
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
        stride2 = self.stride2[0]
        padding2 = self.padding2[0]
        dilation2 = self.dilation2[0]
        x = F.conv1d(
            x,
            weight2,
            self.bias2,
            stride2,
            padding2,
            dilation2,
            self.groups,
            padding_mode=self.padding_mode,
        )
        x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

        return x

    @property
    def weight(self):
        return self.weight2


def test_dual_conv3d_consistency():
    # Initialize parameters
    in_channels = 3
    out_channels = 5
    kernel_size = (3, 3, 3)
    stride = (2, 2, 2)
    padding = (1, 1, 1)

    # Create an instance of the DualConv3d class
    dual_conv3d = DualConv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )

    # Example input tensor
    test_input = torch.randn(1, 3, 10, 10, 10)

    # Perform forward passes with both 3D and 2D settings
    output_conv3d = dual_conv3d(test_input, use_conv3d=True)
    output_2d = dual_conv3d(test_input, use_conv3d=False)

    # Assert that the outputs from both methods are sufficiently close
    assert torch.allclose(
        output_conv3d, output_2d, atol=1e-6
    ), "Outputs are not consistent between 3D and 2D convolutions."

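# --- Editor's note (not part of the committed file) ---
# DualConv3d is a (2+1)D factorization of a full 3D convolution: weight1 applies a
# 1 x k x k spatial convolution and weight2 a k x 1 x 1 temporal convolution, and
# test_dual_conv3d_consistency above checks that forward_with_2d (batched 2D + 1D calls)
# matches forward_with_3d on the same weights.
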
ltx_video/models/autoencoders/latent_upsampler.py
ADDED
@@ -0,0 +1,203 @@
from typing import Optional, Union
from pathlib import Path
import os
import json

import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open

from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND


class ResBlock(nn.Module):
    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = channels

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activation(x + residual)
        return x


class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to spatially upsample VAE latents.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()

        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        if spatial_upsample and temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )

        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = latent.shape

        if self.dims == 2:
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                x = x[:, :, 1:, :, :]
            else:
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x

    @classmethod
    def from_config(cls, config):
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    def config(self):
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
            ".safetensors"
        ):
            state_dict = {}
            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            config = json.loads(metadata["config"])
            with torch.device("meta"):
                latent_upsampler = LatentUpsampler.from_config(config)
            latent_upsampler.load_state_dict(state_dict, assign=True)
            return latent_upsampler


if __name__ == "__main__":
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(p.numel() for p in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    latent = torch.randn(1, 128, 9, 16, 16)
    upsampled_latent = latent_upsampler(latent)
    print(f"Upsampled latent shape: {upsampled_latent.shape}")

ltx_video/models/autoencoders/pixel_norm.py
ADDED
@@ -0,0 +1,12 @@
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)

ltx_video/models/autoencoders/pixel_shuffle.py
ADDED
@@ -0,0 +1,33 @@
import torch.nn as nn
from einops import rearrange


class PixelShuffleND(nn.Module):
    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
                p3=self.upscale_factors[2],
            )
        elif self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
            )
        elif self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=self.upscale_factors[0],
            )

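# --- Illustrative shape check (editor's note; not part of the committed file) ---
# PixelShuffleND trades channel groups for resolution along 1, 2, or 3 axes; the sizes
# below are illustrative only.
def _example_pixel_shuffle_nd():
    import torch

    shuffle_3d = PixelShuffleND(dims=3, upscale_factors=(2, 2, 2))
    x = torch.randn(1, 64, 4, 8, 8)          # (B, C, D, H, W)
    y = shuffle_3d(x)
    assert y.shape == (1, 8, 8, 16, 16)      # channels / 8, each spatial/temporal axis x2
    return y
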
ltx_video/models/autoencoders/vae.py
ADDED
@@ -0,0 +1,380 @@
| 1 |
+
from typing import Optional, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import inspect
|
| 5 |
+
import math
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from diffusers import ConfigMixin, ModelMixin
|
| 8 |
+
from diffusers.models.autoencoders.vae import (
|
| 9 |
+
DecoderOutput,
|
| 10 |
+
DiagonalGaussianDistribution,
|
| 11 |
+
)
|
| 12 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 13 |
+
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
|
| 17 |
+
"""Variational Autoencoder (VAE) model with KL loss.
|
| 18 |
+
|
| 19 |
+
VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
|
| 20 |
+
This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
encoder (`nn.Module`):
|
| 24 |
+
Encoder module.
|
| 25 |
+
decoder (`nn.Module`):
|
| 26 |
+
Decoder module.
|
| 27 |
+
latent_channels (`int`, *optional*, defaults to 4):
|
| 28 |
+
Number of latent channels.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
encoder: nn.Module,
|
| 34 |
+
decoder: nn.Module,
|
| 35 |
+
latent_channels: int = 4,
|
| 36 |
+
dims: int = 2,
|
| 37 |
+
sample_size=512,
|
| 38 |
+
use_quant_conv: bool = True,
|
| 39 |
+
normalize_latent_channels: bool = False,
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
|
| 43 |
+
# pass init params to Encoder
|
| 44 |
+
self.encoder = encoder
|
| 45 |
+
self.use_quant_conv = use_quant_conv
|
| 46 |
+
self.normalize_latent_channels = normalize_latent_channels
|
| 47 |
+
|
| 48 |
+
# pass init params to Decoder
|
| 49 |
+
quant_dims = 2 if dims == 2 else 3
|
| 50 |
+
self.decoder = decoder
|
| 51 |
+
if use_quant_conv:
|
| 52 |
+
self.quant_conv = make_conv_nd(
|
| 53 |
+
quant_dims, 2 * latent_channels, 2 * latent_channels, 1
|
| 54 |
+
)
|
| 55 |
+
self.post_quant_conv = make_conv_nd(
|
| 56 |
+
quant_dims, latent_channels, latent_channels, 1
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
self.quant_conv = nn.Identity()
|
| 60 |
+
self.post_quant_conv = nn.Identity()
|
| 61 |
+
|
| 62 |
+
if normalize_latent_channels:
|
| 63 |
+
if dims == 2:
|
| 64 |
+
self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
|
| 65 |
+
else:
|
| 66 |
+
self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
|
| 67 |
+
else:
|
| 68 |
+
self.latent_norm_out = nn.Identity()
|
| 69 |
+
self.use_z_tiling = False
|
| 70 |
+
self.use_hw_tiling = False
|
| 71 |
+
self.dims = dims
|
| 72 |
+
self.z_sample_size = 1
|
| 73 |
+
|
| 74 |
+
self.decoder_params = inspect.signature(self.decoder.forward).parameters
|
| 75 |
+
|
| 76 |
+
# only relevant if vae tiling is enabled
|
| 77 |
+
self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
|
| 78 |
+
|
| 79 |
+
def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
|
| 80 |
+
self.tile_sample_min_size = sample_size
|
| 81 |
+
num_blocks = len(self.encoder.down_blocks)
|
| 82 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
|
| 83 |
+
self.tile_overlap_factor = overlap_factor
|
| 84 |
+
|
| 85 |
+
def enable_z_tiling(self, z_sample_size: int = 8):
|
| 86 |
+
r"""
|
| 87 |
+
Enable tiling during VAE decoding.
|
| 88 |
+
|
| 89 |
+
When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
|
| 90 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
| 91 |
+
"""
|
| 92 |
+
self.use_z_tiling = z_sample_size > 1
|
| 93 |
+
self.z_sample_size = z_sample_size
|
| 94 |
+
assert (
|
| 95 |
+
z_sample_size % 8 == 0 or z_sample_size == 1
|
| 96 |
+
), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
|
| 97 |
+
|
| 98 |
+
def disable_z_tiling(self):
|
| 99 |
+
r"""
|
| 100 |
+
Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
|
| 101 |
+
decoding in one step.
|
| 102 |
+
"""
|
| 103 |
+
self.use_z_tiling = False
|
| 104 |
+
|
| 105 |
+
def enable_hw_tiling(self):
|
| 106 |
+
r"""
|
| 107 |
+
Enable tiling during VAE decoding along the height and width dimension.
|
| 108 |
+
"""
|
| 109 |
+
self.use_hw_tiling = True
|
| 110 |
+
|
| 111 |
+
def disable_hw_tiling(self):
|
| 112 |
+
r"""
|
| 113 |
+
Disable tiling during VAE decoding along the height and width dimension.
|
| 114 |
+
"""
|
| 115 |
+
self.use_hw_tiling = False
|
| 116 |
+
|
| 117 |
+
def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
|
| 118 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
| 119 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
| 120 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
| 121 |
+
|
| 122 |
+
# Split the image into 512x512 tiles and encode them separately.
|
| 123 |
+
rows = []
|
| 124 |
+
for i in range(0, x.shape[3], overlap_size):
|
| 125 |
+
row = []
|
| 126 |
+
for j in range(0, x.shape[4], overlap_size):
|
| 127 |
+
tile = x[
|
| 128 |
+
:,
|
| 129 |
+
:,
|
| 130 |
+
:,
|
| 131 |
+
i : i + self.tile_sample_min_size,
|
| 132 |
+
j : j + self.tile_sample_min_size,
|
| 133 |
+
]
|
| 134 |
+
tile = self.encoder(tile)
|
| 135 |
+
tile = self.quant_conv(tile)
|
| 136 |
+
row.append(tile)
|
| 137 |
+
rows.append(row)
|
| 138 |
+
result_rows = []
|
| 139 |
+
for i, row in enumerate(rows):
|
| 140 |
+
result_row = []
|
| 141 |
+
for j, tile in enumerate(row):
|
| 142 |
+
# blend the above tile and the left tile
|
| 143 |
+
# to the current tile and add the current tile to the result row
|
| 144 |
+
if i > 0:
|
| 145 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
| 146 |
+
if j > 0:
|
| 147 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
| 148 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
| 149 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
| 150 |
+
|
| 151 |
+
moments = torch.cat(result_rows, dim=3)
|
| 152 |
+
return moments
|
| 153 |
+
|
| 154 |
+
def blend_z(
|
| 155 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
| 156 |
+
) -> torch.Tensor:
|
| 157 |
+
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
| 158 |
+
for z in range(blend_extent):
|
| 159 |
+
b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
|
| 160 |
+
1 - z / blend_extent
|
| 161 |
+
) + b[:, :, z, :, :] * (z / blend_extent)
|
| 162 |
+
return b
|
| 163 |
+
|
| 164 |
+
def blend_v(
|
| 165 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
| 166 |
+
) -> torch.Tensor:
|
| 167 |
+
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
| 168 |
+
for y in range(blend_extent):
|
| 169 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
|
| 170 |
+
1 - y / blend_extent
|
| 171 |
+
) + b[:, :, :, y, :] * (y / blend_extent)
|
| 172 |
+
return b
|
| 173 |
+
|
| 174 |
+
def blend_h(
|
| 175 |
+
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
|
| 176 |
+
) -> torch.Tensor:
|
| 177 |
+
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
| 178 |
+
for x in range(blend_extent):
|
| 179 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
|
| 180 |
+
1 - x / blend_extent
|
| 181 |
+
) + b[:, :, :, :, x] * (x / blend_extent)
|
| 182 |
+
return b
|
| 183 |
+
|
| 184 |
+
def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
|
| 185 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
| 186 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
| 187 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
| 188 |
+
tile_target_shape = (
|
| 189 |
+
*target_shape[:3],
|
| 190 |
+
self.tile_sample_min_size,
|
| 191 |
+
self.tile_sample_min_size,
|
| 192 |
+
)
|
| 193 |
+
# Split z into overlapping 64x64 tiles and decode them separately.
|
| 194 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 195 |
+
rows = []
|
| 196 |
+
for i in range(0, z.shape[3], overlap_size):
|
| 197 |
+
row = []
|
| 198 |
+
for j in range(0, z.shape[4], overlap_size):
|
| 199 |
+
tile = z[
|
| 200 |
+
:,
|
| 201 |
+
:,
|
| 202 |
+
:,
|
| 203 |
+
i : i + self.tile_latent_min_size,
|
| 204 |
+
j : j + self.tile_latent_min_size,
|
| 205 |
+
]
|
| 206 |
+
tile = self.post_quant_conv(tile)
|
| 207 |
+
decoded = self.decoder(tile, target_shape=tile_target_shape)
|
| 208 |
+
row.append(decoded)
|
| 209 |
+
rows.append(row)
|
| 210 |
+
result_rows = []
|
| 211 |
+
for i, row in enumerate(rows):
|
| 212 |
+
result_row = []
|
| 213 |
+
for j, tile in enumerate(row):
|
| 214 |
+
# blend the above tile and the left tile
|
| 215 |
+
# to the current tile and add the current tile to the result row
|
| 216 |
+
if i > 0:
|
| 217 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
| 218 |
+
if j > 0:
|
| 219 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
| 220 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
| 221 |
+
result_rows.append(torch.cat(result_row, dim=4))
|
| 222 |
+
|
| 223 |
+
dec = torch.cat(result_rows, dim=3)
|
| 224 |
+
return dec
|
| 225 |
+
|
| 226 |
+
def encode(
|
| 227 |
+
self, z: torch.FloatTensor, return_dict: bool = True
|
| 228 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 229 |
+
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
| 230 |
+
num_splits = z.shape[2] // self.z_sample_size
|
| 231 |
+
sizes = [self.z_sample_size] * num_splits
|
| 232 |
+
sizes = (
|
| 233 |
+
sizes + [z.shape[2] - sum(sizes)]
|
| 234 |
+
if z.shape[2] - sum(sizes) > 0
|
| 235 |
+
else sizes
|
| 236 |
+
)
|
| 237 |
+
tiles = z.split(sizes, dim=2)
|
| 238 |
+
moments_tiles = [
|
| 239 |
+
(
|
| 240 |
+
self._hw_tiled_encode(z_tile, return_dict)
|
| 241 |
+
if self.use_hw_tiling
|
| 242 |
+
else self._encode(z_tile)
|
| 243 |
+
)
|
| 244 |
+
for z_tile in tiles
|
| 245 |
+
]
|
| 246 |
+
moments = torch.cat(moments_tiles, dim=2)
|
| 247 |
+
|
| 248 |
+
else:
|
| 249 |
+
moments = (
|
| 250 |
+
self._hw_tiled_encode(z, return_dict)
|
| 251 |
+
if self.use_hw_tiling
|
| 252 |
+
else self._encode(z)
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
posterior = DiagonalGaussianDistribution(moments)
|
| 256 |
+
if not return_dict:
|
| 257 |
+
return (posterior,)
|
| 258 |
+
|
| 259 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 260 |
+
|
| 261 |
+
def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
|
| 262 |
+
if isinstance(self.latent_norm_out, nn.BatchNorm3d):
|
| 263 |
+
_, c, _, _, _ = z.shape
|
| 264 |
+
z = torch.cat(
|
| 265 |
+
[
|
| 266 |
+
self.latent_norm_out(z[:, : c // 2, :, :, :]),
|
| 267 |
+
z[:, c // 2 :, :, :, :],
|
| 268 |
+
],
|
| 269 |
+
dim=1,
|
| 270 |
+
)
|
| 271 |
+
elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
|
| 272 |
+
raise NotImplementedError("BatchNorm2d not supported")
|
| 273 |
+
return z
|
| 274 |
+
|
| 275 |
+
def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
|
| 276 |
+
if isinstance(self.latent_norm_out, nn.BatchNorm3d):
|
| 277 |
+
running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
|
| 278 |
+
running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
|
| 279 |
+
eps = self.latent_norm_out.eps
|
| 280 |
+
|
| 281 |
+
z = z * torch.sqrt(running_var + eps) + running_mean
|
| 282 |
+
elif isinstance(self.latent_norm_out, nn.BatchNorm3d):
|
| 283 |
+
raise NotImplementedError("BatchNorm2d not supported")
|
| 284 |
+
return z
|
| 285 |
+
|
| 286 |
+
def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
|
| 287 |
+
h = self.encoder(x)
|
| 288 |
+
moments = self.quant_conv(h)
|
| 289 |
+
moments = self._normalize_latent_channels(moments)
|
| 290 |
+
return moments
|
| 291 |
+
|
| 292 |
+
def _decode(
|
| 293 |
+
self,
|
| 294 |
+
z: torch.FloatTensor,
|
| 295 |
+
target_shape=None,
|
| 296 |
+
timestep: Optional[torch.Tensor] = None,
|
| 297 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 298 |
+
z = self._unnormalize_latent_channels(z)
|
| 299 |
+
z = self.post_quant_conv(z)
|
| 300 |
+
if "timestep" in self.decoder_params:
|
| 301 |
+
dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
|
| 302 |
+
else:
|
| 303 |
+
dec = self.decoder(z, target_shape=target_shape)
|
| 304 |
+
return dec
|
| 305 |
+
|
| 306 |
+
def decode(
|
| 307 |
+
self,
|
| 308 |
+
z: torch.FloatTensor,
|
| 309 |
+
return_dict: bool = True,
|
| 310 |
+
target_shape=None,
|
| 311 |
+
timestep: Optional[torch.Tensor] = None,
|
| 312 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 313 |
+
assert target_shape is not None, "target_shape must be provided for decoding"
|
| 314 |
+
if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
|
| 315 |
+
reduction_factor = int(
|
| 316 |
+
self.encoder.patch_size_t
|
| 317 |
+
* 2
|
| 318 |
+
** (
|
| 319 |
+
len(self.encoder.down_blocks)
|
| 320 |
+
- 1
|
| 321 |
+
- math.sqrt(self.encoder.patch_size)
|
| 322 |
+
)
|
| 323 |
+
)
|
| 324 |
+
split_size = self.z_sample_size // reduction_factor
|
| 325 |
+
num_splits = z.shape[2] // split_size
|
| 326 |
+
|
| 327 |
+
# copy target shape, and divide frame dimension (=2) by the context size
|
| 328 |
+
target_shape_split = list(target_shape)
|
| 329 |
+
target_shape_split[2] = target_shape[2] // num_splits
|
| 330 |
+
|
| 331 |
+
decoded_tiles = [
|
| 332 |
+
(
|
| 333 |
+
self._hw_tiled_decode(z_tile, target_shape_split)
|
| 334 |
+
if self.use_hw_tiling
|
| 335 |
+
else self._decode(z_tile, target_shape=target_shape_split)
|
| 336 |
+
)
|
| 337 |
+
for z_tile in torch.tensor_split(z, num_splits, dim=2)
|
| 338 |
+
]
|
| 339 |
+
decoded = torch.cat(decoded_tiles, dim=2)
|
| 340 |
+
else:
|
| 341 |
+
decoded = (
|
| 342 |
+
self._hw_tiled_decode(z, target_shape)
|
| 343 |
+
if self.use_hw_tiling
|
| 344 |
+
else self._decode(z, target_shape=target_shape, timestep=timestep)
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
if not return_dict:
|
| 348 |
+
return (decoded,)
|
| 349 |
+
|
| 350 |
+
return DecoderOutput(sample=decoded)
|
| 351 |
+
|
| 352 |
+
def forward(
|
| 353 |
+
self,
|
| 354 |
+
sample: torch.FloatTensor,
|
| 355 |
+
sample_posterior: bool = False,
|
| 356 |
+
return_dict: bool = True,
|
| 357 |
+
generator: Optional[torch.Generator] = None,
|
| 358 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
| 359 |
+
r"""
|
| 360 |
+
Args:
|
| 361 |
+
sample (`torch.FloatTensor`): Input sample.
|
| 362 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
| 363 |
+
Whether to sample from the posterior.
|
| 364 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 365 |
+
Whether to return a [`DecoderOutput`] instead of a plain tuple.
|
| 366 |
+
generator (`torch.Generator`, *optional*):
|
| 367 |
+
Generator used to sample from the posterior.
|
| 368 |
+
"""
|
| 369 |
+
x = sample
|
| 370 |
+
posterior = self.encode(x).latent_dist
|
| 371 |
+
if sample_posterior:
|
| 372 |
+
z = posterior.sample(generator=generator)
|
| 373 |
+
else:
|
| 374 |
+
z = posterior.mode()
|
| 375 |
+
dec = self.decode(z, target_shape=sample.shape).sample
|
| 376 |
+
|
| 377 |
+
if not return_dict:
|
| 378 |
+
return (dec,)
|
| 379 |
+
|
| 380 |
+
return DecoderOutput(sample=dec)
|
ltx_video/models/autoencoders/vae_encode.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
import torch
|
| 3 |
+
from diffusers import AutoencoderKL
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from ltx_video.models.autoencoders.causal_video_autoencoder import (
|
| 9 |
+
CausalVideoAutoencoder,
|
| 10 |
+
)
|
| 11 |
+
from ltx_video.models.autoencoders.video_autoencoder import (
|
| 12 |
+
Downsample3D,
|
| 13 |
+
VideoAutoencoder,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import torch_xla.core.xla_model as xm
|
| 18 |
+
except ImportError:
|
| 19 |
+
xm = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def vae_encode(
|
| 23 |
+
media_items: Tensor,
|
| 24 |
+
vae: AutoencoderKL,
|
| 25 |
+
split_size: int = 1,
|
| 26 |
+
vae_per_channel_normalize=False,
|
| 27 |
+
) -> Tensor:
|
| 28 |
+
"""
|
| 29 |
+
Encodes media items (images or videos) into latent representations using a specified VAE model.
|
| 30 |
+
The function supports processing batches of images or video frames and can handle the processing
|
| 31 |
+
in smaller sub-batches if needed.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
media_items (Tensor): A torch Tensor containing the media items to encode. The expected
|
| 35 |
+
shape is (batch_size, channels, height, width) for images or (batch_size, channels,
|
| 36 |
+
frames, height, width) for videos.
|
| 37 |
+
vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
|
| 38 |
+
pre-configured and loaded with the appropriate model weights.
|
| 39 |
+
split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
|
| 40 |
+
If set to more than 1, the input media items are processed in smaller batches according to
|
| 41 |
+
this value. Defaults to 1, which processes all items in a single batch.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
|
| 45 |
+
to match the input shape, scaled by the model's configuration.
|
| 46 |
+
|
| 47 |
+
Examples:
|
| 48 |
+
>>> import torch
|
| 49 |
+
>>> from diffusers import AutoencoderKL
|
| 50 |
+
>>> vae = AutoencoderKL.from_pretrained('your-model-name')
|
| 51 |
+
>>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames.
|
| 52 |
+
>>> latents = vae_encode(images, vae)
|
| 53 |
+
>>> print(latents.shape) # Output shape will depend on the model's latent configuration.
|
| 54 |
+
|
| 55 |
+
Note:
|
| 56 |
+
In case of a video, the function encodes the media item frame-by frame.
|
| 57 |
+
"""
|
| 58 |
+
is_video_shaped = media_items.dim() == 5
|
| 59 |
+
batch_size, channels = media_items.shape[0:2]
|
| 60 |
+
|
| 61 |
+
if channels != 3:
|
| 62 |
+
raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
|
| 63 |
+
|
| 64 |
+
if is_video_shaped and not isinstance(
|
| 65 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 66 |
+
):
|
| 67 |
+
media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
|
| 68 |
+
if split_size > 1:
|
| 69 |
+
if len(media_items) % split_size != 0:
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"Error: The batch size must be divisible by 'train.vae_bs_split"
|
| 72 |
+
)
|
| 73 |
+
encode_bs = len(media_items) // split_size
|
| 74 |
+
# latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
|
| 75 |
+
latents = []
|
| 76 |
+
if media_items.device.type == "xla":
|
| 77 |
+
xm.mark_step()
|
| 78 |
+
for image_batch in media_items.split(encode_bs):
|
| 79 |
+
latents.append(vae.encode(image_batch).latent_dist.sample())
|
| 80 |
+
if media_items.device.type == "xla":
|
| 81 |
+
xm.mark_step()
|
| 82 |
+
latents = torch.cat(latents, dim=0)
|
| 83 |
+
else:
|
| 84 |
+
latents = vae.encode(media_items).latent_dist.sample()
|
| 85 |
+
|
| 86 |
+
latents = normalize_latents(latents, vae, vae_per_channel_normalize)
|
| 87 |
+
if is_video_shaped and not isinstance(
|
| 88 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 89 |
+
):
|
| 90 |
+
latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
|
| 91 |
+
return latents
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def vae_decode(
|
| 95 |
+
latents: Tensor,
|
| 96 |
+
vae: AutoencoderKL,
|
| 97 |
+
is_video: bool = True,
|
| 98 |
+
split_size: int = 1,
|
| 99 |
+
vae_per_channel_normalize=False,
|
| 100 |
+
timestep=None,
|
| 101 |
+
) -> Tensor:
|
| 102 |
+
is_video_shaped = latents.dim() == 5
|
| 103 |
+
batch_size = latents.shape[0]
|
| 104 |
+
|
| 105 |
+
if is_video_shaped and not isinstance(
|
| 106 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 107 |
+
):
|
| 108 |
+
latents = rearrange(latents, "b c n h w -> (b n) c h w")
|
| 109 |
+
if split_size > 1:
|
| 110 |
+
if len(latents) % split_size != 0:
|
| 111 |
+
raise ValueError(
|
| 112 |
+
"Error: The batch size must be divisible by 'train.vae_bs_split"
|
| 113 |
+
)
|
| 114 |
+
encode_bs = len(latents) // split_size
|
| 115 |
+
image_batch = [
|
| 116 |
+
_run_decoder(
|
| 117 |
+
latent_batch, vae, is_video, vae_per_channel_normalize, timestep
|
| 118 |
+
)
|
| 119 |
+
for latent_batch in latents.split(encode_bs)
|
| 120 |
+
]
|
| 121 |
+
images = torch.cat(image_batch, dim=0)
|
| 122 |
+
else:
|
| 123 |
+
images = _run_decoder(
|
| 124 |
+
latents, vae, is_video, vae_per_channel_normalize, timestep
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
if is_video_shaped and not isinstance(
|
| 128 |
+
vae, (VideoAutoencoder, CausalVideoAutoencoder)
|
| 129 |
+
):
|
| 130 |
+
images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
|
| 131 |
+
return images
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _run_decoder(
|
| 135 |
+
latents: Tensor,
|
| 136 |
+
vae: AutoencoderKL,
|
| 137 |
+
is_video: bool,
|
| 138 |
+
vae_per_channel_normalize=False,
|
| 139 |
+
timestep=None,
|
| 140 |
+
) -> Tensor:
|
| 141 |
+
if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
|
| 142 |
+
*_, fl, hl, wl = latents.shape
|
| 143 |
+
temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
|
| 144 |
+
latents = latents.to(vae.dtype)
|
| 145 |
+
vae_decode_kwargs = {}
|
| 146 |
+
if timestep is not None:
|
| 147 |
+
vae_decode_kwargs["timestep"] = timestep
|
| 148 |
+
image = vae.decode(
|
| 149 |
+
un_normalize_latents(latents, vae, vae_per_channel_normalize),
|
| 150 |
+
return_dict=False,
|
| 151 |
+
target_shape=(
|
| 152 |
+
1,
|
| 153 |
+
3,
|
| 154 |
+
fl * temporal_scale if is_video else 1,
|
| 155 |
+
hl * spatial_scale,
|
| 156 |
+
wl * spatial_scale,
|
| 157 |
+
),
|
| 158 |
+
**vae_decode_kwargs,
|
| 159 |
+
)[0]
|
| 160 |
+
else:
|
| 161 |
+
image = vae.decode(
|
| 162 |
+
un_normalize_latents(latents, vae, vae_per_channel_normalize),
|
| 163 |
+
return_dict=False,
|
| 164 |
+
)[0]
|
| 165 |
+
return image
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
|
| 169 |
+
if isinstance(vae, CausalVideoAutoencoder):
|
| 170 |
+
spatial = vae.spatial_downscale_factor
|
| 171 |
+
temporal = vae.temporal_downscale_factor
|
| 172 |
+
else:
|
| 173 |
+
down_blocks = len(
|
| 174 |
+
[
|
| 175 |
+
block
|
| 176 |
+
for block in vae.encoder.down_blocks
|
| 177 |
+
if isinstance(block.downsample, Downsample3D)
|
| 178 |
+
]
|
| 179 |
+
)
|
| 180 |
+
spatial = vae.config.patch_size * 2**down_blocks
|
| 181 |
+
temporal = (
|
| 182 |
+
vae.config.patch_size_t * 2**down_blocks
|
| 183 |
+
if isinstance(vae, VideoAutoencoder)
|
| 184 |
+
else 1
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return (temporal, spatial, spatial)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def latent_to_pixel_coords(
|
| 191 |
+
latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
|
| 192 |
+
) -> Tensor:
|
| 193 |
+
"""
|
| 194 |
+
Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
|
| 195 |
+
configuration.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
|
| 199 |
+
containing the latent corner coordinates of each token.
|
| 200 |
+
vae (AutoencoderKL): The VAE model
|
| 201 |
+
causal_fix (bool): Whether to take into account the different temporal scale
|
| 202 |
+
of the first frame. Default = False for backwards compatibility.
|
| 203 |
+
Returns:
|
| 204 |
+
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
scale_factors = get_vae_size_scale_factor(vae)
|
| 208 |
+
causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
|
| 209 |
+
pixel_coords = latent_to_pixel_coords_from_factors(
|
| 210 |
+
latent_coords, scale_factors, causal_fix
|
| 211 |
+
)
|
| 212 |
+
return pixel_coords
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def latent_to_pixel_coords_from_factors(
|
| 216 |
+
latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
|
| 217 |
+
) -> Tensor:
|
| 218 |
+
pixel_coords = (
|
| 219 |
+
latent_coords
|
| 220 |
+
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
|
| 221 |
+
)
|
| 222 |
+
if causal_fix:
|
| 223 |
+
# Fix temporal scale for first frame to 1 due to causality
|
| 224 |
+
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
|
| 225 |
+
return pixel_coords
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def normalize_latents(
|
| 229 |
+
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
|
| 230 |
+
) -> Tensor:
|
| 231 |
+
return (
|
| 232 |
+
(latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
|
| 233 |
+
/ vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 234 |
+
if vae_per_channel_normalize
|
| 235 |
+
else latents * vae.config.scaling_factor
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def un_normalize_latents(
|
| 240 |
+
latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
|
| 241 |
+
) -> Tensor:
|
| 242 |
+
return (
|
| 243 |
+
latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 244 |
+
+ vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
|
| 245 |
+
if vae_per_channel_normalize
|
| 246 |
+
else latents / vae.config.scaling_factor
|
| 247 |
+
)
|
ltx_video/models/autoencoders/video_autoencoder.py
ADDED
|
@@ -0,0 +1,1045 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
from typing import Any, Mapping, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional
|
| 11 |
+
|
| 12 |
+
from diffusers.utils import logging
|
| 13 |
+
|
| 14 |
+
from ltx_video.utils.torch_utils import Identity
|
| 15 |
+
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
| 16 |
+
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
|
| 17 |
+
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VideoAutoencoder(AutoencoderKLWrapper):
|
| 23 |
+
@classmethod
|
| 24 |
+
def from_pretrained(
|
| 25 |
+
cls,
|
| 26 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 27 |
+
*args,
|
| 28 |
+
**kwargs,
|
| 29 |
+
):
|
| 30 |
+
config_local_path = pretrained_model_name_or_path / "config.json"
|
| 31 |
+
config = cls.load_config(config_local_path, **kwargs)
|
| 32 |
+
video_vae = cls.from_config(config)
|
| 33 |
+
video_vae.to(kwargs["torch_dtype"])
|
| 34 |
+
|
| 35 |
+
model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
|
| 36 |
+
ckpt_state_dict = torch.load(model_local_path)
|
| 37 |
+
video_vae.load_state_dict(ckpt_state_dict)
|
| 38 |
+
|
| 39 |
+
statistics_local_path = (
|
| 40 |
+
pretrained_model_name_or_path / "per_channel_statistics.json"
|
| 41 |
+
)
|
| 42 |
+
if statistics_local_path.exists():
|
| 43 |
+
with open(statistics_local_path, "r") as file:
|
| 44 |
+
data = json.load(file)
|
| 45 |
+
transposed_data = list(zip(*data["data"]))
|
| 46 |
+
data_dict = {
|
| 47 |
+
col: torch.tensor(vals)
|
| 48 |
+
for col, vals in zip(data["columns"], transposed_data)
|
| 49 |
+
}
|
| 50 |
+
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 51 |
+
video_vae.register_buffer(
|
| 52 |
+
"mean_of_means",
|
| 53 |
+
data_dict.get(
|
| 54 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
| 55 |
+
),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
return video_vae
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
def from_config(config):
|
| 62 |
+
assert (
|
| 63 |
+
config["_class_name"] == "VideoAutoencoder"
|
| 64 |
+
), "config must have _class_name=VideoAutoencoder"
|
| 65 |
+
if isinstance(config["dims"], list):
|
| 66 |
+
config["dims"] = tuple(config["dims"])
|
| 67 |
+
|
| 68 |
+
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
| 69 |
+
|
| 70 |
+
double_z = config.get("double_z", True)
|
| 71 |
+
latent_log_var = config.get(
|
| 72 |
+
"latent_log_var", "per_channel" if double_z else "none"
|
| 73 |
+
)
|
| 74 |
+
use_quant_conv = config.get("use_quant_conv", True)
|
| 75 |
+
|
| 76 |
+
if use_quant_conv and latent_log_var == "uniform":
|
| 77 |
+
raise ValueError("uniform latent_log_var requires use_quant_conv=False")
|
| 78 |
+
|
| 79 |
+
encoder = Encoder(
|
| 80 |
+
dims=config["dims"],
|
| 81 |
+
in_channels=config.get("in_channels", 3),
|
| 82 |
+
out_channels=config["latent_channels"],
|
| 83 |
+
block_out_channels=config["block_out_channels"],
|
| 84 |
+
patch_size=config.get("patch_size", 1),
|
| 85 |
+
latent_log_var=latent_log_var,
|
| 86 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
| 87 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
| 88 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
decoder = Decoder(
|
| 92 |
+
dims=config["dims"],
|
| 93 |
+
in_channels=config["latent_channels"],
|
| 94 |
+
out_channels=config.get("out_channels", 3),
|
| 95 |
+
block_out_channels=config["block_out_channels"],
|
| 96 |
+
patch_size=config.get("patch_size", 1),
|
| 97 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
| 98 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
| 99 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
dims = config["dims"]
|
| 103 |
+
return VideoAutoencoder(
|
| 104 |
+
encoder=encoder,
|
| 105 |
+
decoder=decoder,
|
| 106 |
+
latent_channels=config["latent_channels"],
|
| 107 |
+
dims=dims,
|
| 108 |
+
use_quant_conv=use_quant_conv,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def config(self):
|
| 113 |
+
return SimpleNamespace(
|
| 114 |
+
_class_name="VideoAutoencoder",
|
| 115 |
+
dims=self.dims,
|
| 116 |
+
in_channels=self.encoder.conv_in.in_channels
|
| 117 |
+
// (self.encoder.patch_size_t * self.encoder.patch_size**2),
|
| 118 |
+
out_channels=self.decoder.conv_out.out_channels
|
| 119 |
+
// (self.decoder.patch_size_t * self.decoder.patch_size**2),
|
| 120 |
+
latent_channels=self.decoder.conv_in.in_channels,
|
| 121 |
+
block_out_channels=[
|
| 122 |
+
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
| 123 |
+
for i in range(len(self.encoder.down_blocks))
|
| 124 |
+
],
|
| 125 |
+
scaling_factor=1.0,
|
| 126 |
+
norm_layer=self.encoder.norm_layer,
|
| 127 |
+
patch_size=self.encoder.patch_size,
|
| 128 |
+
latent_log_var=self.encoder.latent_log_var,
|
| 129 |
+
use_quant_conv=self.use_quant_conv,
|
| 130 |
+
patch_size_t=self.encoder.patch_size_t,
|
| 131 |
+
add_channel_padding=self.encoder.add_channel_padding,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def is_video_supported(self):
|
| 136 |
+
"""
|
| 137 |
+
Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
|
| 138 |
+
"""
|
| 139 |
+
return self.dims != 2
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def downscale_factor(self):
|
| 143 |
+
return self.encoder.downsample_factor
|
| 144 |
+
|
| 145 |
+
def to_json_string(self) -> str:
|
| 146 |
+
import json
|
| 147 |
+
|
| 148 |
+
return json.dumps(self.config.__dict__)
|
| 149 |
+
|
| 150 |
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
| 151 |
+
model_keys = set(name for name, _ in self.named_parameters())
|
| 152 |
+
|
| 153 |
+
key_mapping = {
|
| 154 |
+
".resnets.": ".res_blocks.",
|
| 155 |
+
"downsamplers.0": "downsample",
|
| 156 |
+
"upsamplers.0": "upsample",
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
converted_state_dict = {}
|
| 160 |
+
for key, value in state_dict.items():
|
| 161 |
+
for k, v in key_mapping.items():
|
| 162 |
+
key = key.replace(k, v)
|
| 163 |
+
|
| 164 |
+
if "norm" in key and key not in model_keys:
|
| 165 |
+
logger.info(
|
| 166 |
+
f"Removing key {key} from state_dict as it is not present in the model"
|
| 167 |
+
)
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
converted_state_dict[key] = value
|
| 171 |
+
|
| 172 |
+
super().load_state_dict(converted_state_dict, strict=strict)
|
| 173 |
+
|
| 174 |
+
def last_layer(self):
|
| 175 |
+
if hasattr(self.decoder, "conv_out"):
|
| 176 |
+
if isinstance(self.decoder.conv_out, nn.Sequential):
|
| 177 |
+
last_layer = self.decoder.conv_out[-1]
|
| 178 |
+
else:
|
| 179 |
+
last_layer = self.decoder.conv_out
|
| 180 |
+
else:
|
| 181 |
+
last_layer = self.decoder.layers[-1]
|
| 182 |
+
return last_layer
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class Encoder(nn.Module):
|
| 186 |
+
r"""
|
| 187 |
+
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 191 |
+
The number of input channels.
|
| 192 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 193 |
+
The number of output channels.
|
| 194 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 195 |
+
The number of output channels for each block.
|
| 196 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 197 |
+
The number of layers per block.
|
| 198 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 199 |
+
The number of groups for normalization.
|
| 200 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 201 |
+
The patch size to use. Should be a power of 2.
|
| 202 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 203 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 204 |
+
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
| 205 |
+
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self,
|
| 210 |
+
dims: Union[int, Tuple[int, int]] = 3,
|
| 211 |
+
in_channels: int = 3,
|
| 212 |
+
out_channels: int = 3,
|
| 213 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 214 |
+
layers_per_block: int = 2,
|
| 215 |
+
norm_num_groups: int = 32,
|
| 216 |
+
patch_size: Union[int, Tuple[int]] = 1,
|
| 217 |
+
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
| 218 |
+
latent_log_var: str = "per_channel",
|
| 219 |
+
patch_size_t: Optional[int] = None,
|
| 220 |
+
add_channel_padding: Optional[bool] = False,
|
| 221 |
+
):
|
| 222 |
+
super().__init__()
|
| 223 |
+
self.patch_size = patch_size
|
| 224 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
| 225 |
+
self.add_channel_padding = add_channel_padding
|
| 226 |
+
self.layers_per_block = layers_per_block
|
| 227 |
+
self.norm_layer = norm_layer
|
| 228 |
+
self.latent_channels = out_channels
|
| 229 |
+
self.latent_log_var = latent_log_var
|
| 230 |
+
if add_channel_padding:
|
| 231 |
+
in_channels = in_channels * self.patch_size**3
|
| 232 |
+
else:
|
| 233 |
+
in_channels = in_channels * self.patch_size_t * self.patch_size**2
|
| 234 |
+
self.in_channels = in_channels
|
| 235 |
+
output_channel = block_out_channels[0]
|
| 236 |
+
|
| 237 |
+
self.conv_in = make_conv_nd(
|
| 238 |
+
dims=dims,
|
| 239 |
+
in_channels=in_channels,
|
| 240 |
+
out_channels=output_channel,
|
| 241 |
+
kernel_size=3,
|
| 242 |
+
stride=1,
|
| 243 |
+
padding=1,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
self.down_blocks = nn.ModuleList([])
|
| 247 |
+
|
| 248 |
+
for i in range(len(block_out_channels)):
|
| 249 |
+
input_channel = output_channel
|
| 250 |
+
output_channel = block_out_channels[i]
|
| 251 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 252 |
+
|
| 253 |
+
down_block = DownEncoderBlock3D(
|
| 254 |
+
dims=dims,
|
| 255 |
+
in_channels=input_channel,
|
| 256 |
+
out_channels=output_channel,
|
| 257 |
+
num_layers=self.layers_per_block,
|
| 258 |
+
add_downsample=not is_final_block and 2**i >= patch_size,
|
| 259 |
+
resnet_eps=1e-6,
|
| 260 |
+
downsample_padding=0,
|
| 261 |
+
resnet_groups=norm_num_groups,
|
| 262 |
+
norm_layer=norm_layer,
|
| 263 |
+
)
|
| 264 |
+
self.down_blocks.append(down_block)
|
| 265 |
+
|
| 266 |
+
self.mid_block = UNetMidBlock3D(
|
| 267 |
+
dims=dims,
|
| 268 |
+
in_channels=block_out_channels[-1],
|
| 269 |
+
num_layers=self.layers_per_block,
|
| 270 |
+
resnet_eps=1e-6,
|
| 271 |
+
resnet_groups=norm_num_groups,
|
| 272 |
+
norm_layer=norm_layer,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# out
|
| 276 |
+
if norm_layer == "group_norm":
|
| 277 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 278 |
+
num_channels=block_out_channels[-1],
|
| 279 |
+
num_groups=norm_num_groups,
|
| 280 |
+
eps=1e-6,
|
| 281 |
+
)
|
| 282 |
+
elif norm_layer == "pixel_norm":
|
| 283 |
+
self.conv_norm_out = PixelNorm()
|
| 284 |
+
self.conv_act = nn.SiLU()
|
| 285 |
+
|
| 286 |
+
conv_out_channels = out_channels
|
| 287 |
+
if latent_log_var == "per_channel":
|
| 288 |
+
conv_out_channels *= 2
|
| 289 |
+
elif latent_log_var == "uniform":
|
| 290 |
+
conv_out_channels += 1
|
| 291 |
+
elif latent_log_var != "none":
|
| 292 |
+
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
| 293 |
+
self.conv_out = make_conv_nd(
|
| 294 |
+
dims, block_out_channels[-1], conv_out_channels, 3, padding=1
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
self.gradient_checkpointing = False
|
| 298 |
+
|
| 299 |
+
@property
|
| 300 |
+
def downscale_factor(self):
|
| 301 |
+
return (
|
| 302 |
+
2
|
| 303 |
+
** len(
|
| 304 |
+
[
|
| 305 |
+
block
|
| 306 |
+
for block in self.down_blocks
|
| 307 |
+
if isinstance(block.downsample, Downsample3D)
|
| 308 |
+
]
|
| 309 |
+
)
|
| 310 |
+
* self.patch_size
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def forward(
|
| 314 |
+
self, sample: torch.FloatTensor, return_features=False
|
| 315 |
+
) -> torch.FloatTensor:
|
| 316 |
+
r"""The forward method of the `Encoder` class."""
|
| 317 |
+
|
| 318 |
+
downsample_in_time = sample.shape[2] != 1
|
| 319 |
+
|
| 320 |
+
# patchify
|
| 321 |
+
patch_size_t = self.patch_size_t if downsample_in_time else 1
|
| 322 |
+
sample = patchify(
|
| 323 |
+
sample,
|
| 324 |
+
patch_size_hw=self.patch_size,
|
| 325 |
+
patch_size_t=patch_size_t,
|
| 326 |
+
add_channel_padding=self.add_channel_padding,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
sample = self.conv_in(sample)
|
| 330 |
+
|
| 331 |
+
checkpoint_fn = (
|
| 332 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 333 |
+
if self.gradient_checkpointing and self.training
|
| 334 |
+
else lambda x: x
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if return_features:
|
| 338 |
+
features = []
|
| 339 |
+
for down_block in self.down_blocks:
|
| 340 |
+
sample = checkpoint_fn(down_block)(
|
| 341 |
+
sample, downsample_in_time=downsample_in_time
|
| 342 |
+
)
|
| 343 |
+
if return_features:
|
| 344 |
+
features.append(sample)
|
| 345 |
+
|
| 346 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
| 347 |
+
|
| 348 |
+
# post-process
|
| 349 |
+
sample = self.conv_norm_out(sample)
|
| 350 |
+
sample = self.conv_act(sample)
|
| 351 |
+
sample = self.conv_out(sample)
|
| 352 |
+
|
| 353 |
+
if self.latent_log_var == "uniform":
|
| 354 |
+
last_channel = sample[:, -1:, ...]
|
| 355 |
+
num_dims = sample.dim()
|
| 356 |
+
|
| 357 |
+
if num_dims == 4:
|
| 358 |
+
# For shape (B, C, H, W)
|
| 359 |
+
repeated_last_channel = last_channel.repeat(
|
| 360 |
+
1, sample.shape[1] - 2, 1, 1
|
| 361 |
+
)
|
| 362 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 363 |
+
elif num_dims == 5:
|
| 364 |
+
# For shape (B, C, F, H, W)
|
| 365 |
+
repeated_last_channel = last_channel.repeat(
|
| 366 |
+
1, sample.shape[1] - 2, 1, 1, 1
|
| 367 |
+
)
|
| 368 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 369 |
+
else:
|
| 370 |
+
raise ValueError(f"Invalid input shape: {sample.shape}")
|
| 371 |
+
|
| 372 |
+
if return_features:
|
| 373 |
+
features.append(sample[:, : self.latent_channels, ...])
|
| 374 |
+
return sample, features
|
| 375 |
+
return sample
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class Decoder(nn.Module):
|
| 379 |
+
r"""
|
| 380 |
+
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 384 |
+
The number of input channels.
|
| 385 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 386 |
+
The number of output channels.
|
| 387 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 388 |
+
The number of output channels for each block.
|
| 389 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 390 |
+
The number of layers per block.
|
| 391 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 392 |
+
The number of groups for normalization.
|
| 393 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 394 |
+
The patch size to use. Should be a power of 2.
|
| 395 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 396 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
def __init__(
|
| 400 |
+
self,
|
| 401 |
+
dims,
|
| 402 |
+
in_channels: int = 3,
|
| 403 |
+
out_channels: int = 3,
|
| 404 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 405 |
+
layers_per_block: int = 2,
|
| 406 |
+
norm_num_groups: int = 32,
|
| 407 |
+
patch_size: int = 1,
|
| 408 |
+
norm_layer: str = "group_norm",
|
| 409 |
+
patch_size_t: Optional[int] = None,
|
| 410 |
+
add_channel_padding: Optional[bool] = False,
|
| 411 |
+
):
|
| 412 |
+
super().__init__()
|
| 413 |
+
self.patch_size = patch_size
|
| 414 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
| 415 |
+
self.add_channel_padding = add_channel_padding
|
| 416 |
+
self.layers_per_block = layers_per_block
|
| 417 |
+
if add_channel_padding:
|
| 418 |
+
out_channels = out_channels * self.patch_size**3
|
| 419 |
+
else:
|
| 420 |
+
out_channels = out_channels * self.patch_size_t * self.patch_size**2
|
| 421 |
+
self.out_channels = out_channels
|
| 422 |
+
|
| 423 |
+
self.conv_in = make_conv_nd(
|
| 424 |
+
dims,
|
| 425 |
+
in_channels,
|
| 426 |
+
block_out_channels[-1],
|
| 427 |
+
kernel_size=3,
|
| 428 |
+
stride=1,
|
| 429 |
+
padding=1,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
self.mid_block = None
|
| 433 |
+
self.up_blocks = nn.ModuleList([])
|
| 434 |
+
|
| 435 |
+
self.mid_block = UNetMidBlock3D(
|
| 436 |
+
dims=dims,
|
| 437 |
+
in_channels=block_out_channels[-1],
|
| 438 |
+
num_layers=self.layers_per_block,
|
| 439 |
+
resnet_eps=1e-6,
|
| 440 |
+
resnet_groups=norm_num_groups,
|
| 441 |
+
norm_layer=norm_layer,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 445 |
+
output_channel = reversed_block_out_channels[0]
|
| 446 |
+
for i in range(len(reversed_block_out_channels)):
|
| 447 |
+
prev_output_channel = output_channel
|
| 448 |
+
output_channel = reversed_block_out_channels[i]
|
| 449 |
+
|
| 450 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 451 |
+
|
| 452 |
+
up_block = UpDecoderBlock3D(
|
| 453 |
+
dims=dims,
|
| 454 |
+
num_layers=self.layers_per_block + 1,
|
| 455 |
+
in_channels=prev_output_channel,
|
| 456 |
+
out_channels=output_channel,
|
| 457 |
+
add_upsample=not is_final_block
|
| 458 |
+
and 2 ** (len(block_out_channels) - i - 1) > patch_size,
|
| 459 |
+
resnet_eps=1e-6,
|
| 460 |
+
resnet_groups=norm_num_groups,
|
| 461 |
+
norm_layer=norm_layer,
|
| 462 |
+
)
|
| 463 |
+
self.up_blocks.append(up_block)
|
| 464 |
+
|
| 465 |
+
if norm_layer == "group_norm":
|
| 466 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 467 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
| 468 |
+
)
|
| 469 |
+
elif norm_layer == "pixel_norm":
|
| 470 |
+
self.conv_norm_out = PixelNorm()
|
| 471 |
+
|
| 472 |
+
self.conv_act = nn.SiLU()
|
| 473 |
+
self.conv_out = make_conv_nd(
|
| 474 |
+
dims, block_out_channels[0], out_channels, 3, padding=1
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
self.gradient_checkpointing = False
|
| 478 |
+
|
| 479 |
+
def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
| 480 |
+
r"""The forward method of the `Decoder` class."""
|
| 481 |
+
assert target_shape is not None, "target_shape must be provided"
|
| 482 |
+
upsample_in_time = sample.shape[2] < target_shape[2]
|
| 483 |
+
|
| 484 |
+
sample = self.conv_in(sample)
|
| 485 |
+
|
| 486 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
| 487 |
+
|
| 488 |
+
checkpoint_fn = (
|
| 489 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 490 |
+
if self.gradient_checkpointing and self.training
|
| 491 |
+
else lambda x: x
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
| 495 |
+
sample = sample.to(upscale_dtype)
|
| 496 |
+
|
| 497 |
+
for up_block in self.up_blocks:
|
| 498 |
+
sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
|
| 499 |
+
|
| 500 |
+
# post-process
|
| 501 |
+
sample = self.conv_norm_out(sample)
|
| 502 |
+
sample = self.conv_act(sample)
|
| 503 |
+
sample = self.conv_out(sample)
|
| 504 |
+
|
| 505 |
+
# un-patchify
|
| 506 |
+
patch_size_t = self.patch_size_t if upsample_in_time else 1
|
| 507 |
+
sample = unpatchify(
|
| 508 |
+
sample,
|
| 509 |
+
patch_size_hw=self.patch_size,
|
| 510 |
+
patch_size_t=patch_size_t,
|
| 511 |
+
add_channel_padding=self.add_channel_padding,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
return sample
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
class DownEncoderBlock3D(nn.Module):
|
| 518 |
+
def __init__(
|
| 519 |
+
self,
|
| 520 |
+
dims: Union[int, Tuple[int, int]],
|
| 521 |
+
in_channels: int,
|
| 522 |
+
out_channels: int,
|
| 523 |
+
dropout: float = 0.0,
|
| 524 |
+
num_layers: int = 1,
|
| 525 |
+
resnet_eps: float = 1e-6,
|
| 526 |
+
resnet_groups: int = 32,
|
| 527 |
+
add_downsample: bool = True,
|
| 528 |
+
downsample_padding: int = 1,
|
| 529 |
+
norm_layer: str = "group_norm",
|
| 530 |
+
):
|
| 531 |
+
super().__init__()
|
| 532 |
+
res_blocks = []
|
| 533 |
+
|
| 534 |
+
for i in range(num_layers):
|
| 535 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 536 |
+
res_blocks.append(
|
| 537 |
+
ResnetBlock3D(
|
| 538 |
+
dims=dims,
|
| 539 |
+
in_channels=in_channels,
|
| 540 |
+
out_channels=out_channels,
|
| 541 |
+
eps=resnet_eps,
|
| 542 |
+
groups=resnet_groups,
|
| 543 |
+
dropout=dropout,
|
| 544 |
+
norm_layer=norm_layer,
|
| 545 |
+
)
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
| 549 |
+
|
| 550 |
+
if add_downsample:
|
| 551 |
+
self.downsample = Downsample3D(
|
| 552 |
+
dims,
|
| 553 |
+
out_channels,
|
| 554 |
+
out_channels=out_channels,
|
| 555 |
+
padding=downsample_padding,
|
| 556 |
+
)
|
| 557 |
+
else:
|
| 558 |
+
self.downsample = Identity()
|
| 559 |
+
|
| 560 |
+
def forward(
|
| 561 |
+
self, hidden_states: torch.FloatTensor, downsample_in_time
|
| 562 |
+
) -> torch.FloatTensor:
|
| 563 |
+
for resnet in self.res_blocks:
|
| 564 |
+
hidden_states = resnet(hidden_states)
|
| 565 |
+
|
| 566 |
+
hidden_states = self.downsample(
|
| 567 |
+
hidden_states, downsample_in_time=downsample_in_time
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
return hidden_states
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class UNetMidBlock3D(nn.Module):
|
| 574 |
+
"""
|
| 575 |
+
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
| 576 |
+
|
| 577 |
+
Args:
|
| 578 |
+
in_channels (`int`): The number of input channels.
|
| 579 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
| 580 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
| 581 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
| 582 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
| 583 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
| 584 |
+
|
| 585 |
+
Returns:
|
| 586 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
| 587 |
+
in_channels, height, width)`.
|
| 588 |
+
|
| 589 |
+
"""
|
| 590 |
+
|
| 591 |
+
def __init__(
|
| 592 |
+
self,
|
| 593 |
+
dims: Union[int, Tuple[int, int]],
|
| 594 |
+
in_channels: int,
|
| 595 |
+
dropout: float = 0.0,
|
| 596 |
+
num_layers: int = 1,
|
| 597 |
+
resnet_eps: float = 1e-6,
|
| 598 |
+
resnet_groups: int = 32,
|
| 599 |
+
norm_layer: str = "group_norm",
|
| 600 |
+
):
|
| 601 |
+
super().__init__()
|
| 602 |
+
resnet_groups = (
|
| 603 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
self.res_blocks = nn.ModuleList(
|
| 607 |
+
[
|
| 608 |
+
ResnetBlock3D(
|
| 609 |
+
dims=dims,
|
| 610 |
+
in_channels=in_channels,
|
| 611 |
+
out_channels=in_channels,
|
| 612 |
+
eps=resnet_eps,
|
| 613 |
+
groups=resnet_groups,
|
| 614 |
+
dropout=dropout,
|
| 615 |
+
norm_layer=norm_layer,
|
| 616 |
+
)
|
| 617 |
+
for _ in range(num_layers)
|
| 618 |
+
]
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 622 |
+
for resnet in self.res_blocks:
|
| 623 |
+
hidden_states = resnet(hidden_states)
|
| 624 |
+
|
| 625 |
+
return hidden_states
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
class UpDecoderBlock3D(nn.Module):
|
| 629 |
+
def __init__(
|
| 630 |
+
self,
|
| 631 |
+
dims: Union[int, Tuple[int, int]],
|
| 632 |
+
in_channels: int,
|
| 633 |
+
out_channels: int,
|
| 634 |
+
resolution_idx: Optional[int] = None,
|
| 635 |
+
dropout: float = 0.0,
|
| 636 |
+
num_layers: int = 1,
|
| 637 |
+
resnet_eps: float = 1e-6,
|
| 638 |
+
resnet_groups: int = 32,
|
| 639 |
+
add_upsample: bool = True,
|
| 640 |
+
norm_layer: str = "group_norm",
|
| 641 |
+
):
|
| 642 |
+
super().__init__()
|
| 643 |
+
res_blocks = []
|
| 644 |
+
|
| 645 |
+
for i in range(num_layers):
|
| 646 |
+
input_channels = in_channels if i == 0 else out_channels
|
| 647 |
+
|
| 648 |
+
res_blocks.append(
|
| 649 |
+
ResnetBlock3D(
|
| 650 |
+
dims=dims,
|
| 651 |
+
in_channels=input_channels,
|
| 652 |
+
out_channels=out_channels,
|
| 653 |
+
eps=resnet_eps,
|
| 654 |
+
groups=resnet_groups,
|
| 655 |
+
dropout=dropout,
|
| 656 |
+
norm_layer=norm_layer,
|
| 657 |
+
)
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
| 661 |
+
|
| 662 |
+
if add_upsample:
|
| 663 |
+
self.upsample = Upsample3D(
|
| 664 |
+
dims=dims, channels=out_channels, out_channels=out_channels
|
| 665 |
+
)
|
| 666 |
+
else:
|
| 667 |
+
self.upsample = Identity()
|
| 668 |
+
|
| 669 |
+
self.resolution_idx = resolution_idx
|
| 670 |
+
|
| 671 |
+
def forward(
|
| 672 |
+
self, hidden_states: torch.FloatTensor, upsample_in_time=True
|
| 673 |
+
) -> torch.FloatTensor:
|
| 674 |
+
for resnet in self.res_blocks:
|
| 675 |
+
hidden_states = resnet(hidden_states)
|
| 676 |
+
|
| 677 |
+
hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)
|
| 678 |
+
|
| 679 |
+
return hidden_states
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
class ResnetBlock3D(nn.Module):
|
| 683 |
+
r"""
|
| 684 |
+
A Resnet block.
|
| 685 |
+
|
| 686 |
+
Parameters:
|
| 687 |
+
in_channels (`int`): The number of channels in the input.
|
| 688 |
+
out_channels (`int`, *optional*, default to be `None`):
|
| 689 |
+
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
| 690 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
| 691 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
| 692 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
| 693 |
+
"""
|
| 694 |
+
|
| 695 |
+
def __init__(
|
| 696 |
+
self,
|
| 697 |
+
dims: Union[int, Tuple[int, int]],
|
| 698 |
+
in_channels: int,
|
| 699 |
+
out_channels: Optional[int] = None,
|
| 700 |
+
conv_shortcut: bool = False,
|
| 701 |
+
dropout: float = 0.0,
|
| 702 |
+
groups: int = 32,
|
| 703 |
+
eps: float = 1e-6,
|
| 704 |
+
norm_layer: str = "group_norm",
|
| 705 |
+
):
|
| 706 |
+
super().__init__()
|
| 707 |
+
self.in_channels = in_channels
|
| 708 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 709 |
+
self.out_channels = out_channels
|
| 710 |
+
self.use_conv_shortcut = conv_shortcut
|
| 711 |
+
|
| 712 |
+
if norm_layer == "group_norm":
|
| 713 |
+
self.norm1 = torch.nn.GroupNorm(
|
| 714 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
| 715 |
+
)
|
| 716 |
+
elif norm_layer == "pixel_norm":
|
| 717 |
+
self.norm1 = PixelNorm()
|
| 718 |
+
|
| 719 |
+
self.non_linearity = nn.SiLU()
|
| 720 |
+
|
| 721 |
+
self.conv1 = make_conv_nd(
|
| 722 |
+
dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
if norm_layer == "group_norm":
|
| 726 |
+
self.norm2 = torch.nn.GroupNorm(
|
| 727 |
+
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
| 728 |
+
)
|
| 729 |
+
elif norm_layer == "pixel_norm":
|
| 730 |
+
self.norm2 = PixelNorm()
|
| 731 |
+
|
| 732 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 733 |
+
|
| 734 |
+
self.conv2 = make_conv_nd(
|
| 735 |
+
dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
| 736 |
+
)
|
| 737 |
+
|
| 738 |
+
self.conv_shortcut = (
|
| 739 |
+
make_linear_nd(
|
| 740 |
+
dims=dims, in_channels=in_channels, out_channels=out_channels
|
| 741 |
+
)
|
| 742 |
+
if in_channels != out_channels
|
| 743 |
+
else nn.Identity()
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
def forward(
|
| 747 |
+
self,
|
| 748 |
+
input_tensor: torch.FloatTensor,
|
| 749 |
+
) -> torch.FloatTensor:
|
| 750 |
+
hidden_states = input_tensor
|
| 751 |
+
|
| 752 |
+
hidden_states = self.norm1(hidden_states)
|
| 753 |
+
|
| 754 |
+
hidden_states = self.non_linearity(hidden_states)
|
| 755 |
+
|
| 756 |
+
hidden_states = self.conv1(hidden_states)
|
| 757 |
+
|
| 758 |
+
hidden_states = self.norm2(hidden_states)
|
| 759 |
+
|
| 760 |
+
hidden_states = self.non_linearity(hidden_states)
|
| 761 |
+
|
| 762 |
+
hidden_states = self.dropout(hidden_states)
|
| 763 |
+
|
| 764 |
+
hidden_states = self.conv2(hidden_states)
|
| 765 |
+
|
| 766 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
| 767 |
+
|
| 768 |
+
output_tensor = input_tensor + hidden_states
|
| 769 |
+
|
| 770 |
+
return output_tensor
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
class Downsample3D(nn.Module):
|
| 774 |
+
def __init__(
|
| 775 |
+
self,
|
| 776 |
+
dims,
|
| 777 |
+
in_channels: int,
|
| 778 |
+
out_channels: int,
|
| 779 |
+
kernel_size: int = 3,
|
| 780 |
+
padding: int = 1,
|
| 781 |
+
):
|
| 782 |
+
super().__init__()
|
| 783 |
+
stride: int = 2
|
| 784 |
+
self.padding = padding
|
| 785 |
+
self.in_channels = in_channels
|
| 786 |
+
self.dims = dims
|
| 787 |
+
self.conv = make_conv_nd(
|
| 788 |
+
dims=dims,
|
| 789 |
+
in_channels=in_channels,
|
| 790 |
+
out_channels=out_channels,
|
| 791 |
+
kernel_size=kernel_size,
|
| 792 |
+
stride=stride,
|
| 793 |
+
padding=padding,
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
def forward(self, x, downsample_in_time=True):
|
| 797 |
+
conv = self.conv
|
| 798 |
+
if self.padding == 0:
|
| 799 |
+
if self.dims == 2:
|
| 800 |
+
padding = (0, 1, 0, 1)
|
| 801 |
+
else:
|
| 802 |
+
padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
|
| 803 |
+
|
| 804 |
+
x = functional.pad(x, padding, mode="constant", value=0)
|
| 805 |
+
|
| 806 |
+
if self.dims == (2, 1) and not downsample_in_time:
|
| 807 |
+
return conv(x, skip_time_conv=True)
|
| 808 |
+
|
| 809 |
+
return conv(x)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
class Upsample3D(nn.Module):
|
| 813 |
+
"""
|
| 814 |
+
An upsampling layer for 3D tensors of shape (B, C, D, H, W).
|
| 815 |
+
|
| 816 |
+
:param channels: channels in the inputs and outputs.
|
| 817 |
+
"""
|
| 818 |
+
|
| 819 |
+
def __init__(self, dims, channels, out_channels=None):
|
| 820 |
+
super().__init__()
|
| 821 |
+
self.dims = dims
|
| 822 |
+
self.channels = channels
|
| 823 |
+
self.out_channels = out_channels or channels
|
| 824 |
+
self.conv = make_conv_nd(
|
| 825 |
+
dims, channels, out_channels, kernel_size=3, padding=1, bias=True
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
def forward(self, x, upsample_in_time):
|
| 829 |
+
if self.dims == 2:
|
| 830 |
+
x = functional.interpolate(
|
| 831 |
+
x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
|
| 832 |
+
)
|
| 833 |
+
else:
|
| 834 |
+
time_scale_factor = 2 if upsample_in_time else 1
|
| 835 |
+
# print("before:", x.shape)
|
| 836 |
+
b, c, d, h, w = x.shape
|
| 837 |
+
x = rearrange(x, "b c d h w -> (b d) c h w")
|
| 838 |
+
# height and width interpolate
|
| 839 |
+
x = functional.interpolate(
|
| 840 |
+
x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
|
| 841 |
+
)
|
| 842 |
+
_, _, h, w = x.shape
|
| 843 |
+
|
| 844 |
+
if not upsample_in_time and self.dims == (2, 1):
|
| 845 |
+
x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
|
| 846 |
+
return self.conv(x, skip_time_conv=True)
|
| 847 |
+
|
| 848 |
+
# Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
|
| 849 |
+
x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
|
| 850 |
+
|
| 851 |
+
# (b h w) c 1 d
|
| 852 |
+
new_d = x.shape[-1] * time_scale_factor
|
| 853 |
+
x = functional.interpolate(x, (1, new_d), mode="nearest")
|
| 854 |
+
# (b h w) c 1 new_d
|
| 855 |
+
x = rearrange(
|
| 856 |
+
x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
|
| 857 |
+
)
|
| 858 |
+
# b c d h w
|
| 859 |
+
|
| 860 |
+
# x = functional.interpolate(
|
| 861 |
+
# x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
| 862 |
+
# )
|
| 863 |
+
# print("after:", x.shape)
|
| 864 |
+
|
| 865 |
+
return self.conv(x)
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
| 869 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
| 870 |
+
return x
|
| 871 |
+
if x.dim() == 4:
|
| 872 |
+
x = rearrange(
|
| 873 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
| 874 |
+
)
|
| 875 |
+
elif x.dim() == 5:
|
| 876 |
+
x = rearrange(
|
| 877 |
+
x,
|
| 878 |
+
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
| 879 |
+
p=patch_size_t,
|
| 880 |
+
q=patch_size_hw,
|
| 881 |
+
r=patch_size_hw,
|
| 882 |
+
)
|
| 883 |
+
else:
|
| 884 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 885 |
+
|
| 886 |
+
if (
|
| 887 |
+
(x.dim() == 5)
|
| 888 |
+
and (patch_size_hw > patch_size_t)
|
| 889 |
+
and (patch_size_t > 1 or add_channel_padding)
|
| 890 |
+
):
|
| 891 |
+
channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
|
| 892 |
+
padding_zeros = torch.zeros(
|
| 893 |
+
x.shape[0],
|
| 894 |
+
channels_to_pad,
|
| 895 |
+
x.shape[2],
|
| 896 |
+
x.shape[3],
|
| 897 |
+
x.shape[4],
|
| 898 |
+
device=x.device,
|
| 899 |
+
dtype=x.dtype,
|
| 900 |
+
)
|
| 901 |
+
x = torch.cat([padding_zeros, x], dim=1)
|
| 902 |
+
|
| 903 |
+
return x
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
|
| 907 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
| 908 |
+
return x
|
| 909 |
+
|
| 910 |
+
if (
|
| 911 |
+
(x.dim() == 5)
|
| 912 |
+
and (patch_size_hw > patch_size_t)
|
| 913 |
+
and (patch_size_t > 1 or add_channel_padding)
|
| 914 |
+
):
|
| 915 |
+
channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
|
| 916 |
+
x = x[:, :channels_to_keep, :, :, :]
|
| 917 |
+
|
| 918 |
+
if x.dim() == 4:
|
| 919 |
+
x = rearrange(
|
| 920 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
| 921 |
+
)
|
| 922 |
+
elif x.dim() == 5:
|
| 923 |
+
x = rearrange(
|
| 924 |
+
x,
|
| 925 |
+
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
| 926 |
+
p=patch_size_t,
|
| 927 |
+
q=patch_size_hw,
|
| 928 |
+
r=patch_size_hw,
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
return x
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
def create_video_autoencoder_config(
|
| 935 |
+
latent_channels: int = 4,
|
| 936 |
+
):
|
| 937 |
+
config = {
|
| 938 |
+
"_class_name": "VideoAutoencoder",
|
| 939 |
+
"dims": (
|
| 940 |
+
2,
|
| 941 |
+
1,
|
| 942 |
+
), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
| 943 |
+
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 944 |
+
"out_channels": 3, # Number of output color channels
|
| 945 |
+
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 946 |
+
"block_out_channels": [
|
| 947 |
+
128,
|
| 948 |
+
256,
|
| 949 |
+
512,
|
| 950 |
+
512,
|
| 951 |
+
], # Number of output channels of each encoder / decoder inner block
|
| 952 |
+
"patch_size": 1,
|
| 953 |
+
}
|
| 954 |
+
|
| 955 |
+
return config
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
def create_video_autoencoder_pathify4x4x4_config(
|
| 959 |
+
latent_channels: int = 4,
|
| 960 |
+
):
|
| 961 |
+
config = {
|
| 962 |
+
"_class_name": "VideoAutoencoder",
|
| 963 |
+
"dims": (
|
| 964 |
+
2,
|
| 965 |
+
1,
|
| 966 |
+
), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
| 967 |
+
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 968 |
+
"out_channels": 3, # Number of output color channels
|
| 969 |
+
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 970 |
+
"block_out_channels": [512]
|
| 971 |
+
* 4, # Number of output channels of each encoder / decoder inner block
|
| 972 |
+
"patch_size": 4,
|
| 973 |
+
"latent_log_var": "uniform",
|
| 974 |
+
}
|
| 975 |
+
|
| 976 |
+
return config
|
| 977 |
+
|
| 978 |
+
|
| 979 |
+
def create_video_autoencoder_pathify4x4_config(
|
| 980 |
+
latent_channels: int = 4,
|
| 981 |
+
):
|
| 982 |
+
config = {
|
| 983 |
+
"_class_name": "VideoAutoencoder",
|
| 984 |
+
"dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
|
| 985 |
+
"in_channels": 3, # Number of input color channels (e.g., RGB)
|
| 986 |
+
"out_channels": 3, # Number of output color channels
|
| 987 |
+
"latent_channels": latent_channels, # Number of channels in the latent space representation
|
| 988 |
+
"block_out_channels": [512]
|
| 989 |
+
* 4, # Number of output channels of each encoder / decoder inner block
|
| 990 |
+
"patch_size": 4,
|
| 991 |
+
"norm_layer": "pixel_norm",
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
return config
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def test_vae_patchify_unpatchify():
|
| 998 |
+
import torch
|
| 999 |
+
|
| 1000 |
+
x = torch.randn(2, 3, 8, 64, 64)
|
| 1001 |
+
x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
|
| 1002 |
+
x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
|
| 1003 |
+
assert torch.allclose(x, x_unpatched)
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def demo_video_autoencoder_forward_backward():
|
| 1007 |
+
# Configuration for the VideoAutoencoder
|
| 1008 |
+
config = create_video_autoencoder_pathify4x4x4_config()
|
| 1009 |
+
|
| 1010 |
+
# Instantiate the VideoAutoencoder with the specified configuration
|
| 1011 |
+
video_autoencoder = VideoAutoencoder.from_config(config)
|
| 1012 |
+
|
| 1013 |
+
print(video_autoencoder)
|
| 1014 |
+
|
| 1015 |
+
# Print the total number of parameters in the video autoencoder
|
| 1016 |
+
total_params = sum(p.numel() for p in video_autoencoder.parameters())
|
| 1017 |
+
print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
|
| 1018 |
+
|
| 1019 |
+
# Create a mock input tensor simulating a batch of videos
|
| 1020 |
+
# Shape: (batch_size, channels, depth, height, width)
|
| 1021 |
+
# E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame
|
| 1022 |
+
input_videos = torch.randn(2, 3, 8, 64, 64)
|
| 1023 |
+
|
| 1024 |
+
# Forward pass: encode and decode the input videos
|
| 1025 |
+
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
| 1026 |
+
print(f"input shape={input_videos.shape}")
|
| 1027 |
+
print(f"latent shape={latent.shape}")
|
| 1028 |
+
reconstructed_videos = video_autoencoder.decode(
|
| 1029 |
+
latent, target_shape=input_videos.shape
|
| 1030 |
+
).sample
|
| 1031 |
+
|
| 1032 |
+
print(f"reconstructed shape={reconstructed_videos.shape}")
|
| 1033 |
+
|
| 1034 |
+
# Calculate the loss (e.g., mean squared error)
|
| 1035 |
+
loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
|
| 1036 |
+
|
| 1037 |
+
# Perform backward pass
|
| 1038 |
+
loss.backward()
|
| 1039 |
+
|
| 1040 |
+
print(f"Demo completed with loss: {loss.item()}")
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
# Ensure to call the demo function to execute the forward and backward pass
|
| 1044 |
+
if __name__ == "__main__":
|
| 1045 |
+
demo_video_autoencoder_forward_backward()
|
ltx_video/models/transformers/__init__.py
ADDED
|
File without changes
|
ltx_video/models/transformers/attention.py
ADDED
|
@@ -0,0 +1,1264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from importlib import import_module
|
| 3 |
+
from typing import Any, Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
|
| 8 |
+
from diffusers.models.attention import _chunked_feed_forward
|
| 9 |
+
from diffusers.models.attention_processor import (
|
| 10 |
+
LoRAAttnAddedKVProcessor,
|
| 11 |
+
LoRAAttnProcessor,
|
| 12 |
+
LoRAAttnProcessor2_0,
|
| 13 |
+
LoRAXFormersAttnProcessor,
|
| 14 |
+
SpatialNorm,
|
| 15 |
+
)
|
| 16 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
| 17 |
+
from diffusers.models.normalization import RMSNorm
|
| 18 |
+
from diffusers.utils import deprecate, logging
|
| 19 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 20 |
+
from einops import rearrange
|
| 21 |
+
from torch import nn
|
| 22 |
+
|
| 23 |
+
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from torch_xla.experimental.custom_kernel import flash_attention
|
| 27 |
+
except ImportError:
|
| 28 |
+
# workaround for automatic tests. Currently this function is manually patched
|
| 29 |
+
# to the torch_xla lib on setup of container
|
| 30 |
+
pass
|
| 31 |
+
|
| 32 |
+
# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@maybe_allow_in_graph
|
| 38 |
+
class BasicTransformerBlock(nn.Module):
|
| 39 |
+
r"""
|
| 40 |
+
A basic Transformer block.
|
| 41 |
+
|
| 42 |
+
Parameters:
|
| 43 |
+
dim (`int`): The number of channels in the input and output.
|
| 44 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
| 45 |
+
attention_head_dim (`int`): The number of channels in each head.
|
| 46 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 47 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
| 48 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 49 |
+
num_embeds_ada_norm (:
|
| 50 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
| 51 |
+
attention_bias (:
|
| 52 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
| 53 |
+
only_cross_attention (`bool`, *optional*):
|
| 54 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
| 55 |
+
double_self_attention (`bool`, *optional*):
|
| 56 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
| 57 |
+
upcast_attention (`bool`, *optional*):
|
| 58 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
| 59 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
| 60 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
| 61 |
+
qk_norm (`str`, *optional*, defaults to None):
|
| 62 |
+
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
| 63 |
+
adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
|
| 64 |
+
The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
|
| 65 |
+
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
|
| 66 |
+
The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
|
| 67 |
+
final_dropout (`bool` *optional*, defaults to False):
|
| 68 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
| 69 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
| 70 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
| 71 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
| 72 |
+
The type of positional embeddings to apply to.
|
| 73 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
| 74 |
+
The maximum number of positional embeddings to apply.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
dim: int,
|
| 80 |
+
num_attention_heads: int,
|
| 81 |
+
attention_head_dim: int,
|
| 82 |
+
dropout=0.0,
|
| 83 |
+
cross_attention_dim: Optional[int] = None,
|
| 84 |
+
activation_fn: str = "geglu",
|
| 85 |
+
num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument
|
| 86 |
+
attention_bias: bool = False,
|
| 87 |
+
only_cross_attention: bool = False,
|
| 88 |
+
double_self_attention: bool = False,
|
| 89 |
+
upcast_attention: bool = False,
|
| 90 |
+
norm_elementwise_affine: bool = True,
|
| 91 |
+
adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none'
|
| 92 |
+
standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
|
| 93 |
+
norm_eps: float = 1e-5,
|
| 94 |
+
qk_norm: Optional[str] = None,
|
| 95 |
+
final_dropout: bool = False,
|
| 96 |
+
attention_type: str = "default", # pylint: disable=unused-argument
|
| 97 |
+
ff_inner_dim: Optional[int] = None,
|
| 98 |
+
ff_bias: bool = True,
|
| 99 |
+
attention_out_bias: bool = True,
|
| 100 |
+
use_tpu_flash_attention: bool = False,
|
| 101 |
+
use_rope: bool = False,
|
| 102 |
+
):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.only_cross_attention = only_cross_attention
|
| 105 |
+
self.use_tpu_flash_attention = use_tpu_flash_attention
|
| 106 |
+
self.adaptive_norm = adaptive_norm
|
| 107 |
+
|
| 108 |
+
assert standardization_norm in ["layer_norm", "rms_norm"]
|
| 109 |
+
assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
|
| 110 |
+
|
| 111 |
+
make_norm_layer = (
|
| 112 |
+
nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
| 116 |
+
# 1. Self-Attn
|
| 117 |
+
self.norm1 = make_norm_layer(
|
| 118 |
+
dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
self.attn1 = Attention(
|
| 122 |
+
query_dim=dim,
|
| 123 |
+
heads=num_attention_heads,
|
| 124 |
+
dim_head=attention_head_dim,
|
| 125 |
+
dropout=dropout,
|
| 126 |
+
bias=attention_bias,
|
| 127 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
| 128 |
+
upcast_attention=upcast_attention,
|
| 129 |
+
out_bias=attention_out_bias,
|
| 130 |
+
use_tpu_flash_attention=use_tpu_flash_attention,
|
| 131 |
+
qk_norm=qk_norm,
|
| 132 |
+
use_rope=use_rope,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# 2. Cross-Attn
|
| 136 |
+
if cross_attention_dim is not None or double_self_attention:
|
| 137 |
+
self.attn2 = Attention(
|
| 138 |
+
query_dim=dim,
|
| 139 |
+
cross_attention_dim=(
|
| 140 |
+
cross_attention_dim if not double_self_attention else None
|
| 141 |
+
),
|
| 142 |
+
heads=num_attention_heads,
|
| 143 |
+
dim_head=attention_head_dim,
|
| 144 |
+
dropout=dropout,
|
| 145 |
+
bias=attention_bias,
|
| 146 |
+
upcast_attention=upcast_attention,
|
| 147 |
+
out_bias=attention_out_bias,
|
| 148 |
+
use_tpu_flash_attention=use_tpu_flash_attention,
|
| 149 |
+
qk_norm=qk_norm,
|
| 150 |
+
use_rope=use_rope,
|
| 151 |
+
) # is self-attn if encoder_hidden_states is none
|
| 152 |
+
|
| 153 |
+
if adaptive_norm == "none":
|
| 154 |
+
self.attn2_norm = make_norm_layer(
|
| 155 |
+
dim, norm_eps, norm_elementwise_affine
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
self.attn2 = None
|
| 159 |
+
self.attn2_norm = None
|
| 160 |
+
|
| 161 |
+
self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
|
| 162 |
+
|
| 163 |
+
# 3. Feed-forward
|
| 164 |
+
self.ff = FeedForward(
|
| 165 |
+
dim,
|
| 166 |
+
dropout=dropout,
|
| 167 |
+
activation_fn=activation_fn,
|
| 168 |
+
final_dropout=final_dropout,
|
| 169 |
+
inner_dim=ff_inner_dim,
|
| 170 |
+
bias=ff_bias,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# 5. Scale-shift for PixArt-Alpha.
|
| 174 |
+
if adaptive_norm != "none":
|
| 175 |
+
num_ada_params = 4 if adaptive_norm == "single_scale" else 6
|
| 176 |
+
self.scale_shift_table = nn.Parameter(
|
| 177 |
+
torch.randn(num_ada_params, dim) / dim**0.5
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# let chunk size default to None
|
| 181 |
+
self._chunk_size = None
|
| 182 |
+
self._chunk_dim = 0
|
| 183 |
+
|
| 184 |
+
def set_use_tpu_flash_attention(self):
|
| 185 |
+
r"""
|
| 186 |
+
Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
|
| 187 |
+
attention kernel.
|
| 188 |
+
"""
|
| 189 |
+
self.use_tpu_flash_attention = True
|
| 190 |
+
self.attn1.set_use_tpu_flash_attention()
|
| 191 |
+
self.attn2.set_use_tpu_flash_attention()
|
| 192 |
+
|
| 193 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
| 194 |
+
# Sets chunk feed-forward
|
| 195 |
+
self._chunk_size = chunk_size
|
| 196 |
+
self._chunk_dim = dim
|
| 197 |
+
|
| 198 |
+
def forward(
|
| 199 |
+
self,
|
| 200 |
+
hidden_states: torch.FloatTensor,
|
| 201 |
+
freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
|
| 202 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 203 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 204 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
| 205 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 206 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
| 207 |
+
class_labels: Optional[torch.LongTensor] = None,
|
| 208 |
+
skip_layer_mask: Optional[torch.Tensor] = None,
|
| 209 |
+
skip_layer_strategy: Optional[SkipLayerStrategy] = None,
|
| 210 |
+
) -> torch.FloatTensor:
|
| 211 |
+
if cross_attention_kwargs is not None:
|
| 212 |
+
if cross_attention_kwargs.get("scale", None) is not None:
|
| 213 |
+
logger.warning(
|
| 214 |
+
"Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored."
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
| 218 |
+
# 0. Self-Attention
|
| 219 |
+
batch_size = hidden_states.shape[0]
|
| 220 |
+
|
| 221 |
+
original_hidden_states = hidden_states
|
| 222 |
+
|
| 223 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 224 |
+
|
| 225 |
+
# Apply ada_norm_single
|
| 226 |
+
if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
|
| 227 |
+
assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim]
|
| 228 |
+
num_ada_params = self.scale_shift_table.shape[0]
|
| 229 |
+
ada_values = self.scale_shift_table[None, None] + timestep.reshape(
|
| 230 |
+
batch_size, timestep.shape[1], num_ada_params, -1
|
| 231 |
+
)
|
| 232 |
+
if self.adaptive_norm == "single_scale_shift":
|
| 233 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 234 |
+
ada_values.unbind(dim=2)
|
| 235 |
+
)
|
| 236 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
| 237 |
+
else:
|
| 238 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
| 239 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa)
|
| 240 |
+
elif self.adaptive_norm == "none":
|
| 241 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
|
| 244 |
+
|
| 245 |
+
norm_hidden_states = norm_hidden_states.squeeze(
|
| 246 |
+
1
|
| 247 |
+
) # TODO: Check if this is needed
|
| 248 |
+
|
| 249 |
+
# 1. Prepare GLIGEN inputs
|
| 250 |
+
cross_attention_kwargs = (
|
| 251 |
+
cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
attn_output = self.attn1(
|
| 255 |
+
norm_hidden_states,
|
| 256 |
+
freqs_cis=freqs_cis,
|
| 257 |
+
encoder_hidden_states=(
|
| 258 |
+
encoder_hidden_states if self.only_cross_attention else None
|
| 259 |
+
),
|
| 260 |
+
attention_mask=attention_mask,
|
| 261 |
+
skip_layer_mask=skip_layer_mask,
|
| 262 |
+
skip_layer_strategy=skip_layer_strategy,
|
| 263 |
+
**cross_attention_kwargs,
|
| 264 |
+
)
|
| 265 |
+
if gate_msa is not None:
|
| 266 |
+
attn_output = gate_msa * attn_output
|
| 267 |
+
|
| 268 |
+
hidden_states = attn_output + hidden_states
|
| 269 |
+
if hidden_states.ndim == 4:
|
| 270 |
+
hidden_states = hidden_states.squeeze(1)
|
| 271 |
+
|
| 272 |
+
# 3. Cross-Attention
|
| 273 |
+
if self.attn2 is not None:
|
| 274 |
+
if self.adaptive_norm == "none":
|
| 275 |
+
attn_input = self.attn2_norm(hidden_states)
|
| 276 |
+
else:
|
| 277 |
+
attn_input = hidden_states
|
| 278 |
+
attn_output = self.attn2(
|
| 279 |
+
attn_input,
|
| 280 |
+
freqs_cis=freqs_cis,
|
| 281 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 282 |
+
attention_mask=encoder_attention_mask,
|
| 283 |
+
**cross_attention_kwargs,
|
| 284 |
+
)
|
| 285 |
+
hidden_states = attn_output + hidden_states
|
| 286 |
+
|
| 287 |
+
# 4. Feed-forward
|
| 288 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 289 |
+
if self.adaptive_norm == "single_scale_shift":
|
| 290 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
| 291 |
+
elif self.adaptive_norm == "single_scale":
|
| 292 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
|
| 293 |
+
elif self.adaptive_norm == "none":
|
| 294 |
+
pass
|
| 295 |
+
else:
|
| 296 |
+
raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
|
| 297 |
+
|
| 298 |
+
if self._chunk_size is not None:
|
| 299 |
+
# "feed_forward_chunk_size" can be used to save memory
|
| 300 |
+
ff_output = _chunked_feed_forward(
|
| 301 |
+
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
|
| 302 |
+
)
|
| 303 |
+
else:
|
| 304 |
+
ff_output = self.ff(norm_hidden_states)
|
| 305 |
+
if gate_mlp is not None:
|
| 306 |
+
ff_output = gate_mlp * ff_output
|
| 307 |
+
|
| 308 |
+
hidden_states = ff_output + hidden_states
|
| 309 |
+
if hidden_states.ndim == 4:
|
| 310 |
+
hidden_states = hidden_states.squeeze(1)
|
| 311 |
+
|
| 312 |
+
if (
|
| 313 |
+
skip_layer_mask is not None
|
| 314 |
+
and skip_layer_strategy == SkipLayerStrategy.TransformerBlock
|
| 315 |
+
):
|
| 316 |
+
skip_layer_mask = skip_layer_mask.view(-1, 1, 1)
|
| 317 |
+
hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (
|
| 318 |
+
1.0 - skip_layer_mask
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
return hidden_states
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@maybe_allow_in_graph
|
| 325 |
+
class Attention(nn.Module):
|
| 326 |
+
r"""
|
| 327 |
+
A cross attention layer.
|
| 328 |
+
|
| 329 |
+
Parameters:
|
| 330 |
+
query_dim (`int`):
|
| 331 |
+
The number of channels in the query.
|
| 332 |
+
cross_attention_dim (`int`, *optional*):
|
| 333 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
| 334 |
+
heads (`int`, *optional*, defaults to 8):
|
| 335 |
+
The number of heads to use for multi-head attention.
|
| 336 |
+
dim_head (`int`, *optional*, defaults to 64):
|
| 337 |
+
The number of channels in each head.
|
| 338 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
| 339 |
+
The dropout probability to use.
|
| 340 |
+
bias (`bool`, *optional*, defaults to False):
|
| 341 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
| 342 |
+
upcast_attention (`bool`, *optional*, defaults to False):
|
| 343 |
+
Set to `True` to upcast the attention computation to `float32`.
|
| 344 |
+
upcast_softmax (`bool`, *optional*, defaults to False):
|
| 345 |
+
Set to `True` to upcast the softmax computation to `float32`.
|
| 346 |
+
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
| 347 |
+
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
| 348 |
+
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
| 349 |
+
The number of groups to use for the group norm in the cross attention.
|
| 350 |
+
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
| 351 |
+
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
| 352 |
+
norm_num_groups (`int`, *optional*, defaults to `None`):
|
| 353 |
+
The number of groups to use for the group norm in the attention.
|
| 354 |
+
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
| 355 |
+
The number of channels to use for the spatial normalization.
|
| 356 |
+
out_bias (`bool`, *optional*, defaults to `True`):
|
| 357 |
+
Set to `True` to use a bias in the output linear layer.
|
| 358 |
+
scale_qk (`bool`, *optional*, defaults to `True`):
|
| 359 |
+
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
| 360 |
+
qk_norm (`str`, *optional*, defaults to None):
|
| 361 |
+
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
|
| 362 |
+
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
| 363 |
+
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
| 364 |
+
`added_kv_proj_dim` is not `None`.
|
| 365 |
+
eps (`float`, *optional*, defaults to 1e-5):
|
| 366 |
+
An additional value added to the denominator in group normalization that is used for numerical stability.
|
| 367 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
| 368 |
+
A factor to rescale the output by dividing it with this value.
|
| 369 |
+
residual_connection (`bool`, *optional*, defaults to `False`):
|
| 370 |
+
Set to `True` to add the residual connection to the output.
|
| 371 |
+
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
| 372 |
+
Set to `True` if the attention block is loaded from a deprecated state dict.
|
| 373 |
+
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
| 374 |
+
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
|
| 375 |
+
`AttnProcessor` otherwise.
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
def __init__(
|
| 379 |
+
self,
|
| 380 |
+
query_dim: int,
|
| 381 |
+
cross_attention_dim: Optional[int] = None,
|
| 382 |
+
heads: int = 8,
|
| 383 |
+
dim_head: int = 64,
|
| 384 |
+
dropout: float = 0.0,
|
| 385 |
+
bias: bool = False,
|
| 386 |
+
upcast_attention: bool = False,
|
| 387 |
+
upcast_softmax: bool = False,
|
| 388 |
+
cross_attention_norm: Optional[str] = None,
|
| 389 |
+
cross_attention_norm_num_groups: int = 32,
|
| 390 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 391 |
+
norm_num_groups: Optional[int] = None,
|
| 392 |
+
spatial_norm_dim: Optional[int] = None,
|
| 393 |
+
out_bias: bool = True,
|
| 394 |
+
scale_qk: bool = True,
|
| 395 |
+
qk_norm: Optional[str] = None,
|
| 396 |
+
only_cross_attention: bool = False,
|
| 397 |
+
eps: float = 1e-5,
|
| 398 |
+
rescale_output_factor: float = 1.0,
|
| 399 |
+
residual_connection: bool = False,
|
| 400 |
+
_from_deprecated_attn_block: bool = False,
|
| 401 |
+
processor: Optional["AttnProcessor"] = None,
|
| 402 |
+
out_dim: int = None,
|
| 403 |
+
use_tpu_flash_attention: bool = False,
|
| 404 |
+
use_rope: bool = False,
|
| 405 |
+
):
|
| 406 |
+
super().__init__()
|
| 407 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 408 |
+
self.query_dim = query_dim
|
| 409 |
+
self.use_bias = bias
|
| 410 |
+
self.is_cross_attention = cross_attention_dim is not None
|
| 411 |
+
self.cross_attention_dim = (
|
| 412 |
+
cross_attention_dim if cross_attention_dim is not None else query_dim
|
| 413 |
+
)
|
| 414 |
+
self.upcast_attention = upcast_attention
|
| 415 |
+
self.upcast_softmax = upcast_softmax
|
| 416 |
+
self.rescale_output_factor = rescale_output_factor
|
| 417 |
+
self.residual_connection = residual_connection
|
| 418 |
+
self.dropout = dropout
|
| 419 |
+
self.fused_projections = False
|
| 420 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 421 |
+
self.use_tpu_flash_attention = use_tpu_flash_attention
|
| 422 |
+
self.use_rope = use_rope
|
| 423 |
+
|
| 424 |
+
# we make use of this private variable to know whether this class is loaded
|
| 425 |
+
# with an deprecated state dict so that we can convert it on the fly
|
| 426 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
| 427 |
+
|
| 428 |
+
self.scale_qk = scale_qk
|
| 429 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
| 430 |
+
|
| 431 |
+
if qk_norm is None:
|
| 432 |
+
self.q_norm = nn.Identity()
|
| 433 |
+
self.k_norm = nn.Identity()
|
| 434 |
+
elif qk_norm == "rms_norm":
|
| 435 |
+
self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
|
| 436 |
+
self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
|
| 437 |
+
elif qk_norm == "layer_norm":
|
| 438 |
+
self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
|
| 439 |
+
self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
|
| 440 |
+
else:
|
| 441 |
+
raise ValueError(f"Unsupported qk_norm method: {qk_norm}")
|
| 442 |
+
|
| 443 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 444 |
+
# for slice_size > 0 the attention score computation
|
| 445 |
+
# is split across the batch axis to save memory
|
| 446 |
+
# You can set slice_size with `set_attention_slice`
|
| 447 |
+
self.sliceable_head_dim = heads
|
| 448 |
+
|
| 449 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 450 |
+
self.only_cross_attention = only_cross_attention
|
| 451 |
+
|
| 452 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
| 453 |
+
raise ValueError(
|
| 454 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
if norm_num_groups is not None:
|
| 458 |
+
self.group_norm = nn.GroupNorm(
|
| 459 |
+
num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
|
| 460 |
+
)
|
| 461 |
+
else:
|
| 462 |
+
self.group_norm = None
|
| 463 |
+
|
| 464 |
+
if spatial_norm_dim is not None:
|
| 465 |
+
self.spatial_norm = SpatialNorm(
|
| 466 |
+
f_channels=query_dim, zq_channels=spatial_norm_dim
|
| 467 |
+
)
|
| 468 |
+
else:
|
| 469 |
+
self.spatial_norm = None
|
| 470 |
+
|
| 471 |
+
if cross_attention_norm is None:
|
| 472 |
+
self.norm_cross = None
|
| 473 |
+
elif cross_attention_norm == "layer_norm":
|
| 474 |
+
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
| 475 |
+
elif cross_attention_norm == "group_norm":
|
| 476 |
+
if self.added_kv_proj_dim is not None:
|
| 477 |
+
# The given `encoder_hidden_states` are initially of shape
|
| 478 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
| 479 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
| 480 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
| 481 |
+
# the number of channels for the group norm.
|
| 482 |
+
norm_cross_num_channels = added_kv_proj_dim
|
| 483 |
+
else:
|
| 484 |
+
norm_cross_num_channels = self.cross_attention_dim
|
| 485 |
+
|
| 486 |
+
self.norm_cross = nn.GroupNorm(
|
| 487 |
+
num_channels=norm_cross_num_channels,
|
| 488 |
+
num_groups=cross_attention_norm_num_groups,
|
| 489 |
+
eps=1e-5,
|
| 490 |
+
affine=True,
|
| 491 |
+
)
|
| 492 |
+
else:
|
| 493 |
+
raise ValueError(
|
| 494 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
linear_cls = nn.Linear
|
| 498 |
+
|
| 499 |
+
self.linear_cls = linear_cls
|
| 500 |
+
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
| 501 |
+
|
| 502 |
+
if not self.only_cross_attention:
|
| 503 |
+
# only relevant for the `AddedKVProcessor` classes
|
| 504 |
+
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
| 505 |
+
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
| 506 |
+
else:
|
| 507 |
+
self.to_k = None
|
| 508 |
+
self.to_v = None
|
| 509 |
+
|
| 510 |
+
if self.added_kv_proj_dim is not None:
|
| 511 |
+
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
| 512 |
+
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
| 513 |
+
|
| 514 |
+
self.to_out = nn.ModuleList([])
|
| 515 |
+
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
|
| 516 |
+
self.to_out.append(nn.Dropout(dropout))
|
| 517 |
+
|
| 518 |
+
# set attention processor
|
| 519 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
| 520 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
| 521 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
| 522 |
+
if processor is None:
|
| 523 |
+
processor = AttnProcessor2_0()
|
| 524 |
+
self.set_processor(processor)
|
| 525 |
+
|
| 526 |
+
def set_use_tpu_flash_attention(self):
|
| 527 |
+
r"""
|
| 528 |
+
Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
|
| 529 |
+
"""
|
| 530 |
+
self.use_tpu_flash_attention = True
|
| 531 |
+
|
| 532 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
| 533 |
+
r"""
|
| 534 |
+
Set the attention processor to use.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
processor (`AttnProcessor`):
|
| 538 |
+
The attention processor to use.
|
| 539 |
+
"""
|
| 540 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
| 541 |
+
# pop `processor` from `self._modules`
|
| 542 |
+
if (
|
| 543 |
+
hasattr(self, "processor")
|
| 544 |
+
and isinstance(self.processor, torch.nn.Module)
|
| 545 |
+
and not isinstance(processor, torch.nn.Module)
|
| 546 |
+
):
|
| 547 |
+
logger.info(
|
| 548 |
+
f"You are removing possibly trained weights of {self.processor} with {processor}"
|
| 549 |
+
)
|
| 550 |
+
self._modules.pop("processor")
|
| 551 |
+
|
| 552 |
+
self.processor = processor
|
| 553 |
+
|
| 554 |
+
def get_processor(
|
| 555 |
+
self, return_deprecated_lora: bool = False
|
| 556 |
+
) -> "AttentionProcessor": # noqa: F821
|
| 557 |
+
r"""
|
| 558 |
+
Get the attention processor in use.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
| 562 |
+
Set to `True` to return the deprecated LoRA attention processor.
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
"AttentionProcessor": The attention processor in use.
|
| 566 |
+
"""
|
| 567 |
+
if not return_deprecated_lora:
|
| 568 |
+
return self.processor
|
| 569 |
+
|
| 570 |
+
# TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
|
| 571 |
+
# serialization format for LoRA Attention Processors. It should be deleted once the integration
|
| 572 |
+
# with PEFT is completed.
|
| 573 |
+
is_lora_activated = {
|
| 574 |
+
name: module.lora_layer is not None
|
| 575 |
+
for name, module in self.named_modules()
|
| 576 |
+
if hasattr(module, "lora_layer")
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
# 1. if no layer has a LoRA activated we can return the processor as usual
|
| 580 |
+
if not any(is_lora_activated.values()):
|
| 581 |
+
return self.processor
|
| 582 |
+
|
| 583 |
+
# If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
|
| 584 |
+
is_lora_activated.pop("add_k_proj", None)
|
| 585 |
+
is_lora_activated.pop("add_v_proj", None)
|
| 586 |
+
# 2. else it is not posssible that only some layers have LoRA activated
|
| 587 |
+
if not all(is_lora_activated.values()):
|
| 588 |
+
raise ValueError(
|
| 589 |
+
f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
# 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
|
| 593 |
+
non_lora_processor_cls_name = self.processor.__class__.__name__
|
| 594 |
+
lora_processor_cls = getattr(
|
| 595 |
+
import_module(__name__), "LoRA" + non_lora_processor_cls_name
|
| 596 |
+
)
|
| 597 |
+
|
| 598 |
+
hidden_size = self.inner_dim
|
| 599 |
+
|
| 600 |
+
# now create a LoRA attention processor from the LoRA layers
|
| 601 |
+
if lora_processor_cls in [
|
| 602 |
+
LoRAAttnProcessor,
|
| 603 |
+
LoRAAttnProcessor2_0,
|
| 604 |
+
            LoRAXFormersAttnProcessor,
        ]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(
                    self.add_k_proj.lora_layer.state_dict()
                )
                lora_processor.add_v_proj_lora.load_state_dict(
                    self.add_v_proj.lora_layer.state_dict()
                )
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            skip_layer_mask (`torch.Tensor`, *optional*):
                The skip layer mask to use. If `None`, no mask is applied.
            skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
                Controls which layers to skip for spatiotemporal guidance.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        attn_parameters = set(
            inspect.signature(self.processor.__call__).parameters.keys()
        )
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by"
                f" {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {
            k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
        }

        return self.processor(
            self,
            hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
        the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """

        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(
            batch_size, seq_len * extra_dim, head_size, dim // head_size
        )
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(
                batch_size * head_size, seq_len * extra_dim, dim // head_size
            )

        return tensor

    def get_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(
        self,
        attention_mask: torch.Tensor,
        target_length: int,
        batch_size: int,
        out_dim: int = 3,
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (
                    attention_mask.shape[0],
                    attention_mask.shape[1],
                    target_length,
                )
                padding = torch.zeros(
                    padding_shape,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                # we want to instead pad by (0, remaining_length), where remaining_length is:
                # remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(
        self, encoder_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert (
            self.norm_cross is not None
        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states

    @staticmethod
    def apply_rotary_emb(
        input_tensor: torch.Tensor,
        freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos_freqs = freqs_cis[0]
        sin_freqs = freqs_cis[1]

        t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
        t1, t2 = t_dup.unbind(dim=-1)
        t_dup = torch.stack((-t2, t1), dim=-1)
        input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

        out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

        return out

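`apply_rotary_emb` above rotates each adjacent pair of channels by a per-position angle supplied as precomputed cosine/sine tables. A minimal standalone sketch of the same pairwise rotation follows; the tensor shapes and angles are illustrative assumptions, not values taken from the model configuration.

```python
# Standalone sketch of the pairwise rotation used by Attention.apply_rotary_emb above.
# Shapes and angles here are illustrative assumptions.
import torch
from einops import rearrange


def rotate(x: torch.Tensor, cos_f: torch.Tensor, sin_f: torch.Tensor) -> torch.Tensor:
    # Pair up channels (x0, x1), map each pair to (-x1, x0), then mix with cos/sin.
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = pairs.unbind(dim=-1)
    rotated = rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")
    return x * cos_f + rotated * sin_f


tokens = torch.randn(1, 16, 64)         # (batch, sequence, channels)
angles = torch.rand(1, 16, 32) * 3.14   # one angle per channel pair
cos_f = angles.cos().repeat_interleave(2, dim=-1)
sin_f = angles.sin().repeat_interleave(2, dim=-1)

out = rotate(tokens, cos_f, sin_f)
# The rotation is norm-preserving per channel pair, hence per token.
print(out.shape, torch.allclose(tokens.norm(dim=-1), out.norm(dim=-1), atol=1e-5))
```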
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        skip_layer_mask: Optional[torch.FloatTensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if skip_layer_mask is not None:
            skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1)

        if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)
        query = attn.q_norm(query)

        if encoder_hidden_states is not None:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )
            key = attn.to_k(encoder_hidden_states)
            key = attn.k_norm(key)
        else:  # if no context provided do self-attention
            encoder_hidden_states = hidden_states
            key = attn.to_k(hidden_states)
            key = attn.k_norm(key)
            if attn.use_rope:
                key = attn.apply_rotary_emb(key, freqs_cis)
                query = attn.apply_rotary_emb(query, freqs_cis)

        value = attn.to_v(encoder_hidden_states)
        value_for_stg = value

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)

        if attn.use_tpu_flash_attention:  # use tpu attention offload 'flash attention'
            q_segment_indexes = None
            if (
                attention_mask is not None
            ):  # if mask is required need to tune both segmenIds fields
                # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
                attention_mask = attention_mask.to(torch.float32)
                q_segment_indexes = torch.ones(
                    batch_size, query.shape[2], device=query.device, dtype=torch.float32
                )
                assert (
                    attention_mask.shape[1] == key.shape[2]
                ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"

            assert (
                query.shape[2] % 128 == 0
            ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
            assert (
                key.shape[2] % 128 == 0
            ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"

            # run the TPU kernel implemented in jax with pallas
            hidden_states_a = flash_attention(
                q=query,
                k=key,
                v=value,
                q_segment_ids=q_segment_indexes,
                kv_segment_ids=attention_mask,
                sm_scale=attn.scale,
            )
        else:
            hidden_states_a = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )

        hidden_states_a = hidden_states_a.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states_a = hidden_states_a.to(query.dtype)

        if (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
        ):
            hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
                1.0 - skip_layer_mask
            )
        elif (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.AttentionValues
        ):
            hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
                1.0 - skip_layer_mask
            )
        else:
            hidden_states = hidden_states_a

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.Residual
        ):
            skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1)

        if attn.residual_connection:
            if (
                skip_layer_mask is not None
                and skip_layer_strategy == SkipLayerStrategy.Residual
            ):
                hidden_states = hidden_states + residual * skip_layer_mask
            else:
                hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

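The `skip_layer_mask` blending in `AttnProcessor2_0` above is what realizes the spatiotemporal-guidance (STG) perturbation at the attention level: entries where the mask is 0 replace the attention output with the unperturbed input (AttentionSkip) or with the raw values (AttentionValues). A small numeric sketch of that convex blend follows; the tensors are made up for illustration.

```python
# Numeric sketch of the skip-layer blend used in AttnProcessor2_0 above.
# Tensors are made up; in the model, skip_layer_mask comes from
# Transformer3DModel.create_skip_layer_mask and is 0 for the perturbed batch entries.
import torch

batch, seq, dim = 2, 4, 8
attn_out = torch.randn(batch, seq, dim)   # hidden_states_a in the processor
attn_in = torch.randn(batch, seq, dim)    # hidden_states before attention
skip_layer_mask = torch.tensor([1.0, 0.0]).reshape(batch, 1, 1)

# AttentionSkip strategy: where the mask is 0, attention is effectively bypassed.
blended = attn_out * skip_layer_mask + attn_in * (1.0 - skip_layer_mask)

assert torch.equal(blended[0], attn_out[0])  # mask 1 -> normal attention output
assert torch.equal(blended[1], attn_in[1])   # mask 0 -> layer skipped
```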
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        query = attn.q_norm(query)
        key = attn.k_norm(key)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        else:
            raise ValueError(f"Unsupported activation function: {activation_fn}")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        compatible_cls = (GEGLU, LoRACompatibleLinear)
        for module in self.net:
            if isinstance(module, compatible_cls):
                hidden_states = module(hidden_states, scale)
            else:
                hidden_states = module(hidden_states)
        return hidden_states
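With the default `"geglu"` activation, the `FeedForward` block above amounts to one projection to twice the inner width, a GELU gate, and a projection back down. The sketch below mirrors that structure with plain `nn.Linear` layers rather than the exact diffusers activation classes; the dimensions are arbitrary assumptions.

```python
# Minimal sketch of the "geglu" path of the FeedForward block above.
# Dimensions are arbitrary; this mirrors the structure, not the exact diffusers classes.
import torch
import torch.nn.functional as F
from torch import nn

dim, mult = 32, 4
inner_dim = dim * mult

proj_in = nn.Linear(dim, inner_dim * 2)   # what GEGLU does internally
proj_out = nn.Linear(inner_dim, dim)

x = torch.randn(1, 16, dim)               # (batch, tokens, channels)
hidden, gate = proj_in(x).chunk(2, dim=-1)
hidden = hidden * F.gelu(gate)            # gated GELU
out = proj_out(hidden)
print(out.shape)                          # torch.Size([1, 16, 32])
```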
ltx_video/models/transformers/embeddings.py
ADDED
@@ -0,0 +1,129 @@
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
import math

import numpy as np
import torch
from einops import rearrange
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
    """
    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
    grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
    grid = grid.reshape([3, 1, w, h, f])
    pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
    pos_embed = pos_embed.transpose(1, 0, 2, 3)
    return rearrange(pos_embed, "h w f c -> (f h w) c")


def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")

    # use half of dimensions to encode grid_h
    emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0])  # (H*W*T, D/3)
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1])  # (H*W*T, D/3)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2])  # (H*W*T, D/3)

    emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1)  # (H*W*T, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos_shape = pos.shape

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    out = out.reshape([*pos_shape, -1])[0]

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (M, D)
    return emb


class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
    them

    Args:
        embed_dim: (int): Dimension of the positional embedding.
        max_seq_length: Maximum sequence length to apply positional embeddings

    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
        )
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x
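A short usage sketch of `get_timestep_embedding` as defined above; the timestep values and embedding width are arbitrary, and the import assumes the `ltx_video` package from this repository is on the path.

```python
# Usage sketch for the sinusoidal timestep embedding above; values are arbitrary.
import torch
from ltx_video.models.transformers.embeddings import get_timestep_embedding

timesteps = torch.tensor([0.0, 250.0, 999.0])   # one (possibly fractional) step per sample
emb = get_timestep_embedding(timesteps, embedding_dim=256, flip_sin_to_cos=True)
print(emb.shape)  # torch.Size([3, 256]); cosines come first after the flip
```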
ltx_video/models/transformers/symmetric_patchifier.py
ADDED
@@ -0,0 +1,84 @@
from abc import ABC, abstractmethod
from typing import Tuple

import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor


class Patchifier(ConfigMixin, ABC):
    def __init__(self, patch_size: int):
        super().__init__()
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError("Patchify method not implemented")

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the
        top-left corner latent coordinates of each latent patch.
        The tensor is repeated for each batch element.
        """
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
            torch.arange(0, latent_height, self._patch_size[1], device=device),
            torch.arange(0, latent_width, self._patch_size[2], device=device),
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = rearrange(
            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
        )
        return latent_coords


class SymmetricPatchifier(Patchifier):
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
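A round-trip sketch for `SymmetricPatchifier` above: with `patch_size=1` the patchify/unpatchify pair is a pure reshape between a 5D latent and a token sequence plus per-token coordinates. The latent shape is an arbitrary example, and the import assumes the `ltx_video` package from this repository is importable.

```python
# Round-trip sketch for SymmetricPatchifier above; the latent shape is an arbitrary example.
import torch
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier

patchifier = SymmetricPatchifier(patch_size=1)
latents = torch.randn(1, 128, 9, 16, 24)   # (batch, channels, frames, height, width)

tokens, coords = patchifier.patchify(latents)
print(tokens.shape, coords.shape)          # [1, 9*16*24, 128] and [1, 3, 9*16*24]

restored = patchifier.unpatchify(tokens, output_height=16, output_width=24, out_channels=128)
print(torch.equal(restored, latents))      # True: with patch_size=1 this is a pure reshape
```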
ltx_video/models/transformers/transformer3d.py
ADDED
@@ -0,0 +1,507 @@
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import os
import json
import glob
from pathlib import Path

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import PixArtAlphaTextProjection
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils import logging
from torch import nn
from safetensors import safe_open


from ltx_video.models.transformers.attention import BasicTransformerBlock
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    TRANSFORMER_KEYS_RENAME_DICT,
)


logger = logging.get_logger(__name__)


@dataclass
class Transformer3DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    sample: torch.FloatTensor


class Transformer3DModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift' or 'single_scale'
        standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        attention_type: str = "default",
        caption_channels: int = None,
        use_tpu_flash_attention: bool = False,  # if True uses the TPU attention offload ('flash attention')
        qk_norm: Optional[str] = None,
        positional_embedding_type: str = "rope",
        positional_embedding_theta: Optional[float] = None,
        positional_embedding_max_pos: Optional[List[int]] = None,
        timestep_scale_multiplier: Optional[float] = None,
        causal_temporal_positioning: bool = False,  # For backward compatibility, will be deprecated
    ):
        super().__init__()
        self.use_tpu_flash_attention = (
            use_tpu_flash_attention  # FIXME: push config down to the attention modules
        )
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
        self.positional_embedding_type = positional_embedding_type
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.use_rope = self.positional_embedding_type == "rope"
        self.timestep_scale_multiplier = timestep_scale_multiplier

        if self.positional_embedding_type == "absolute":
            raise ValueError("Absolute positional embedding is no longer supported")
        elif self.positional_embedding_type == "rope":
            if positional_embedding_theta is None:
                raise ValueError(
                    "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
                )
            if positional_embedding_max_pos is None:
                raise ValueError(
                    "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
                )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    double_self_attention=double_self_attention,
                    upcast_attention=upcast_attention,
                    adaptive_norm=adaptive_norm,
                    standardization_norm=standardization_norm,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                    attention_type=attention_type,
                    use_tpu_flash_attention=use_tpu_flash_attention,
                    qk_norm=qk_norm,
                    use_rope=self.use_rope,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(
            torch.randn(2, inner_dim) / inner_dim**0.5
        )
        self.proj_out = nn.Linear(inner_dim, self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            inner_dim, use_additional_conditions=False
        )
        if adaptive_norm == "single_scale":
            self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)

        self.caption_projection = None
        if caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=caption_channels, hidden_size=inner_dim
            )

        self.gradient_checkpointing = False

    def set_use_tpu_flash_attention(self):
        r"""
        Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
        attention kernel.
        """
        logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
        self.use_tpu_flash_attention = True
        # push config down to the attention modules
        for block in self.transformer_blocks:
            block.set_use_tpu_flash_attention()

    def create_skip_layer_mask(
        self,
        batch_size: int,
        num_conds: int,
        ptb_index: int,
        skip_block_list: Optional[List[int]] = None,
    ):
        if skip_block_list is None or len(skip_block_list) == 0:
            return None
        num_layers = len(self.transformer_blocks)
        mask = torch.ones(
            (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
        )
        for block_idx in skip_block_list:
            mask[block_idx, ptb_index::num_conds] = 0
        return mask

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def get_fractional_positions(self, indices_grid):
        fractional_positions = torch.stack(
            [
                indices_grid[:, i] / self.positional_embedding_max_pos[i]
                for i in range(3)
            ],
            dim=-1,
        )
        return fractional_positions

    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
        dtype = torch.float32  # We need full precision in the freqs_cis computation.
        dim = self.inner_dim
        theta = self.positional_embedding_theta

        fractional_positions = self.get_fractional_positions(indices_grid)

        start = 1
        end = theta
        device = fractional_positions.device
        if spacing == "exp":
            indices = theta ** (
                torch.linspace(
                    math.log(start, theta),
                    math.log(end, theta),
                    dim // 6,
                    device=device,
                    dtype=dtype,
                )
            )
            indices = indices.to(dtype=dtype)
        elif spacing == "exp_2":
            indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
            indices = indices.to(dtype=dtype)
        elif spacing == "linear":
            indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
        elif spacing == "sqrt":
            indices = torch.linspace(
                start**2, end**2, dim // 6, device=device, dtype=dtype
            ).sqrt()

        indices = indices * math.pi / 2

        if spacing == "exp_2":
            freqs = (
                (indices * fractional_positions.unsqueeze(-1))
                .transpose(-1, -2)
                .flatten(2)
            )
        else:
            freqs = (
                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
                .transpose(-1, -2)
                .flatten(2)
            )

        cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
        sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
        if dim % 6 != 0:
            cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
            sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
            cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
            sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
        return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
def load_state_dict(
|
| 260 |
+
self,
|
| 261 |
+
state_dict: Dict,
|
| 262 |
+
*args,
|
| 263 |
+
**kwargs,
|
| 264 |
+
):
|
| 265 |
+
if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
|
| 266 |
+
state_dict = {
|
| 267 |
+
key.replace("model.diffusion_model.", ""): value
|
| 268 |
+
for key, value in state_dict.items()
|
| 269 |
+
if key.startswith("model.diffusion_model.")
|
| 270 |
+
}
|
| 271 |
+
super().load_state_dict(state_dict, *args, **kwargs)
|
| 272 |
+
|
| 273 |
+
@classmethod
|
| 274 |
+
def from_pretrained(
|
| 275 |
+
cls,
|
| 276 |
+
pretrained_model_path: Optional[Union[str, os.PathLike]],
|
| 277 |
+
*args,
|
| 278 |
+
**kwargs,
|
| 279 |
+
):
|
| 280 |
+
pretrained_model_path = Path(pretrained_model_path)
|
| 281 |
+
if pretrained_model_path.is_dir():
|
| 282 |
+
config_path = pretrained_model_path / "transformer" / "config.json"
|
| 283 |
+
with open(config_path, "r") as f:
|
| 284 |
+
config = make_hashable_key(json.load(f))
|
| 285 |
+
|
| 286 |
+
assert config in diffusers_and_ours_config_mapping, (
|
| 287 |
+
"Provided diffusers checkpoint config for transformer is not suppported. "
|
| 288 |
+
"We only support diffusers configs found in Lightricks/LTX-Video."
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
config = diffusers_and_ours_config_mapping[config]
|
| 292 |
+
state_dict = {}
|
| 293 |
+
ckpt_paths = (
|
| 294 |
+
pretrained_model_path
|
| 295 |
+
/ "transformer"
|
| 296 |
+
/ "diffusion_pytorch_model*.safetensors"
|
| 297 |
+
)
|
| 298 |
+
dict_list = glob.glob(str(ckpt_paths))
|
| 299 |
+
for dict_path in dict_list:
|
| 300 |
+
part_dict = {}
|
| 301 |
+
with safe_open(dict_path, framework="pt", device="cpu") as f:
|
| 302 |
+
for k in f.keys():
|
| 303 |
+
part_dict[k] = f.get_tensor(k)
|
| 304 |
+
state_dict.update(part_dict)
|
| 305 |
+
|
| 306 |
+
for key in list(state_dict.keys()):
|
| 307 |
+
new_key = key
|
| 308 |
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
| 309 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 310 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 311 |
+
|
| 312 |
+
with torch.device("meta"):
|
| 313 |
+
transformer = cls.from_config(config)
|
| 314 |
+
transformer.load_state_dict(state_dict, assign=True, strict=True)
|
| 315 |
+
elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
|
| 316 |
+
".safetensors"
|
| 317 |
+
):
|
| 318 |
+
comfy_single_file_state_dict = {}
|
| 319 |
+
with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
|
| 320 |
+
metadata = f.metadata()
|
| 321 |
+
for k in f.keys():
|
| 322 |
+
comfy_single_file_state_dict[k] = f.get_tensor(k)
|
| 323 |
+
configs = json.loads(metadata["config"])
|
| 324 |
+
transformer_config = configs["transformer"]
|
| 325 |
+
with torch.device("meta"):
|
| 326 |
+
transformer = Transformer3DModel.from_config(transformer_config)
|
| 327 |
+
transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
|
| 328 |
+
return transformer
|
| 329 |
+
|
| 330 |
+
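Because the model registers its constructor arguments through `@register_to_config`, it can also be built directly from a plain config dict via `from_config` (inherited from diffusers' `ConfigMixin`), which is what `from_pretrained` above does internally before loading weights. The sketch below shows that path with small, illustrative values; this is not an official LTX-Video configuration.

```python
# Sketch: instantiating the transformer from a config dict via from_config.
# The values below are illustrative assumptions, not an official LTX-Video configuration.
from ltx_video.models.transformers.transformer3d import Transformer3DModel

config = {
    "num_attention_heads": 4,
    "attention_head_dim": 32,
    "in_channels": 128,
    "out_channels": 128,
    "num_layers": 2,
    "cross_attention_dim": 2048,
    "caption_channels": 2048,
    "positional_embedding_type": "rope",
    "positional_embedding_theta": 10000.0,
    "positional_embedding_max_pos": [20, 2048, 2048],
}
model = Transformer3DModel.from_config(config)
print(sum(p.numel() for p in model.parameters()))
```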
    def forward(
        self,
        hidden_states: torch.Tensor,
        indices_grid: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
            encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask ( `torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
                is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                    * Mask `(batch, sequence_length)` True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            skip_layer_mask ( `torch.Tensor`, *optional*):
                A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position
                `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index.
            skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`):
                Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # for tpu attention offload 2d token masks are used. No need to transform.
        if not self.use_tpu_flash_attention:
            # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
            # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
            # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
            # expects mask of shape:
            #   [batch, key_tokens]
            # adds singleton query_tokens dimension:
            #   [batch, 1, key_tokens]
            # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
            #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
            #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
            if attention_mask is not None and attention_mask.ndim == 2:
                # assume that mask is expressed as:
                #   (1 = keep, 0 = discard)
                # convert mask into a bias that can be added to attention scores:
                #   (keep = +0, discard = -10000.0)
                attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            # convert encoder_attention_mask to a bias the same way we do for attention_mask
            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (
                    1 - encoder_attention_mask.to(hidden_states.dtype)
                ) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        hidden_states = self.patchify_proj(hidden_states)

        if self.timestep_scale_multiplier:
            timestep = self.timestep_scale_multiplier * timestep

        freqs_cis = self.precompute_freqs_cis(indices_grid)

        batch_size = hidden_states.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, -1, hidden_states.shape[-1]
            )

        for block_idx, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    freqs_cis,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    (
                        skip_layer_mask[block_idx]
                        if skip_layer_mask is not None
                        else None
                    ),
                    skip_layer_strategy,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    freqs_cis=freqs_cis,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                    skip_layer_mask=(
                        skip_layer_mask[block_idx]
                        if skip_layer_mask is not None
                        else None
                    ),
                    skip_layer_strategy=skip_layer_strategy,
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+
)
|
| 499 |
+
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
| 500 |
+
hidden_states = self.norm_out(hidden_states)
|
| 501 |
+
# Modulation
|
| 502 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
| 503 |
+
hidden_states = self.proj_out(hidden_states)
|
| 504 |
+
if not return_dict:
|
| 505 |
+
return (hidden_states,)
|
| 506 |
+
|
| 507 |
+
return Transformer3DModelOutput(sample=hidden_states)
|
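The forward pass above turns a 2-D keep/discard mask into an additive attention bias before the transformer blocks run. The following standalone sketch only illustrates that conversion outside the model; the tensor values are made up for the example and are not part of this commit.

```python
import torch

# Illustrative keep/discard mask: batch of 2 sequences, 4 key tokens each
# (1 = keep, 0 = discard), matching the docstring's 2-D mask format.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])

# Same conversion the forward pass applies when attention_mask.ndim == 2:
# kept tokens become 0.0, discarded tokens become -10000.0, and a singleton
# query_tokens dimension is added so the bias broadcasts over
# [batch, heads, query_tokens, key_tokens] attention scores.
bias = (1 - attention_mask.to(torch.float32)) * -10000.0
bias = bias.unsqueeze(1)

print(bias.shape)  # torch.Size([2, 1, 4])
```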
ltx_video/pipelines/__init__.py
ADDED
File without changes
ltx_video/pipelines/crf_compressor.py
ADDED
@@ -0,0 +1,50 @@
import av
import torch
import io
import numpy as np


def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())
    finally:
        container.close()


def _decode_single_frame(video_file):
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")


def compress(image: torch.Tensor, crf=29):
    if crf == 0:
        return image

    image_array = (
        (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
        .byte()
        .cpu()
        .numpy()
    )
    with io.BytesIO() as output_file:
        _encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = _decode_single_frame(video_file)
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor
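A brief usage sketch for the compressor added above: `compress` takes an H×W×3 float tensor in [0, 1], round-trips it through a single-frame H.264 encode/decode at the requested CRF, and returns a tensor with the same dtype and device (height and width are first cropped to even values). The input below is synthetic, purely for illustration.

```python
import torch
from ltx_video.pipelines.crf_compressor import compress

# Synthetic RGB frame: 256x256, float values in [0, 1].
frame = torch.rand(256, 256, 3)

# Round-trip through libx264 at CRF 29 to inject codec artifacts.
degraded = compress(frame, crf=29)
print(degraded.shape, degraded.dtype)  # torch.Size([256, 256, 3]) torch.float32

# crf=0 skips the round-trip and returns the input unchanged.
assert compress(frame, crf=0) is frame
```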