euiia committed
Commit fbc88a7 · verified · 1 Parent(s): 7866a75

Update managers/ltx_manager.py

Files changed (1):
  1. managers/ltx_manager.py +198 -171

managers/ltx_manager.py CHANGED
@@ -2,193 +2,63 @@
  #
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
  #
- # Version: 2.1.0
  #
- # This file manages the LTX-Video specialist pool. It now includes a crucial
- # "monkey patch" for the LTX pipeline's `prepare_conditioning` method. This approach
- # isolates our ADUC-specific modifications from the original library code, ensuring
- # better maintainability and respecting the principle of separation of concerns.

  import torch
  import gc
  import os
  import yaml
  import logging
  import huggingface_hub
  import time
  import threading
  from typing import Optional, List, Tuple, Union

- from tools.optimization import optimize_ltx_worker, can_optimize_fp8
- from tools.hardware_manager import hardware_manager
- from managers.ltx_pipeline_utils import create_ltx_video_pipeline, calculate_padding
- from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem
- from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
- from ltx_video.pipelines.pipeline_ltx_video import LTXMultiScalePipeline
- from diffusers.utils.torch_utils import randn_tensor

  logger = logging.getLogger(__name__)


- # --- MONKEY PATCHING SECTION ---
- # This section contains our custom logic that will override the default
- # behavior of the LTX pipeline at runtime.
-
- def _aduc_prepare_conditioning_patch(
-     self: LTXVideoPipeline,  # 'self' will be the instance of the LTXVideoPipeline
-     conditioning_items: Optional[List[Union[ConditioningItem, "LatentConditioningItem"]]],
-     init_latents: torch.Tensor,
-     num_frames: int,
-     height: int,
-     width: int,
-     vae_per_channel_normalize: bool = False,
-     generator=None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
-     """
-     This is our custom version of the `prepare_conditioning` method.
-     It correctly handles both standard ConditioningItem (from pixels) and our
-     ADUC-specific LatentConditioningItem (from latents), which the original
-     method does not. This function will replace the original one at runtime.
-     """
-     if not conditioning_items:
-         init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-         init_pixel_coords = latent_to_pixel_coords(
-             init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning
-         )
-         return init_latents, init_pixel_coords, None, 0
-
-     init_conditioning_mask = torch.zeros(init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device)
-     extra_conditioning_latents = []
-     extra_conditioning_pixel_coords = []
-     extra_conditioning_mask = []
-     extra_conditioning_num_latents = 0
-
-     is_latent_mode = hasattr(conditioning_items[0], 'latent_tensor')
-
-     if is_latent_mode:
-         for item in conditioning_items:
-             media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
-             media_frame_number = item.media_frame_number
-             strength = item.conditioning_strength
-
-             if media_frame_number == 0:
-                 f_l, h_l, w_l = media_item_latents.shape[-3:]
-                 init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
-                 init_conditioning_mask[:, :f_l, :h_l, :w_l] = strength
-             else:
-                 noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
-                 media_item_latents = torch.lerp(noise, media_item_latents, strength)
-                 patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
-                 pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-                 pixel_coords[:, 0] += media_frame_number
-                 extra_conditioning_num_latents += patched_latents.shape[1]
-                 new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
-                 extra_conditioning_latents.append(patched_latents)
-                 extra_conditioning_pixel_coords.append(pixel_coords)
-                 extra_conditioning_mask.append(new_mask)
-     else:  # Original pixel-based logic for fallback
-         for item in conditioning_items:
-             if not isinstance(item, ConditioningItem): continue
-             item = self._resize_conditioning_item(item, height, width)
-             media_item_latents = vae_encode(
-                 item.media_item.to(dtype=self.vae.dtype, device=self.vae.device),
-                 self.vae, vae_per_channel_normalize=vae_per_channel_normalize
-             ).to(dtype=init_latents.dtype)
-             media_frame_number = item.media_frame_number
-             strength = item.conditioning_strength
-             if media_frame_number == 0:
-                 media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
-                 f_l, h_l, w_l = media_item_latents.shape[-3:]
-                 init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
-                 init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
-             else:
-                 logger.warning("Pixel-based conditioning for non-zero frames is not fully implemented in this patch.")
-                 pass
-
-     init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
-     init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
-     init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
-     init_conditioning_mask = init_conditioning_mask.squeeze(-1)
-
-     if extra_conditioning_latents:
-         init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
-         init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
-         init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
-         if self.transformer.use_tpu_flash_attention:
-             init_latents = init_latents[:, :-extra_conditioning_num_latents]
-             init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
-             init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
-
-     return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
-
- # --- END OF MONKEY PATCHING SECTION ---
- class LtxWorker:
-     """
-     Represents a single instance of the LTX-Video pipeline on a specific device.
-     Manages model loading to CPU and movement to/from GPU.
-     """
-     def __init__(self, device_id, ltx_config_file):
-         self.cpu_device = torch.device('cpu')
-         self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
-         logger.info(f"LTX Worker ({self.device}): Initializing with config '{ltx_config_file}'...")
-
-         with open(ltx_config_file, "r") as file:
-             self.config = yaml.safe_load(file)
-
-         self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
-
-         models_dir = "downloaded_models_gradio"
-
-         logger.info(f"LTX Worker ({self.device}): Loading model to CPU...")
-         model_path = os.path.join(models_dir, self.config["checkpoint_path"])
-         if not os.path.exists(model_path):
-             model_path = huggingface_hub.hf_hub_download(
-                 repo_id="Lightricks/LTX-Video", filename=self.config["checkpoint_path"],
-                 local_dir=models_dir, local_dir_use_symlinks=False
-             )
-
-         self.pipeline = create_ltx_video_pipeline(
-             ckpt_path=model_path, precision=self.config["precision"],
-             text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
-             sampler=self.config["sampler"], device='cpu'
-         )
-         logger.info(f"LTX Worker ({self.device}): Model ready on CPU. Is distilled model? {self.is_distilled}")
-
-     def to_gpu(self):
-         """Moves the pipeline to the designated GPU AND optimizes if possible."""
-         if self.device.type == 'cpu': return
-         logger.info(f"LTX Worker: Moving pipeline to GPU {self.device}...")
-         self.pipeline.to(self.device)
-
-         if self.device.type == 'cuda' and can_optimize_fp8():
-             logger.info(f"LTX Worker ({self.device}): FP8 supported GPU detected. Optimizing...")
-             optimize_ltx_worker(self)
-             logger.info(f"LTX Worker ({self.device}): Optimization complete.")
-         elif self.device.type == 'cuda':
-             logger.info(f"LTX Worker ({self.device}): FP8 optimization not supported or disabled.")
-
-     def to_cpu(self):
-         """Moves the pipeline back to the CPU and frees GPU memory."""
-         if self.device.type == 'cpu': return
-         logger.info(f"LTX Worker: Unloading pipeline from GPU {self.device}...")
-         self.pipeline.to('cpu')
-         gc.collect()
-         if torch.cuda.is_available(): torch.cuda.empty_cache()
-
-     def generate_video_fragment_internal(self, **kwargs):
-         """Invokes the generation pipeline."""
-         return self.pipeline(**kwargs).images
-
  class LtxPoolManager:
      """
      Manages a pool of LtxWorkers for optimized multi-GPU usage.
-     HOT START MODE: Keeps all models loaded in VRAM for minimum latency.
      """
-     def __init__(self, device_ids, ltx_config_file):
          logger.info(f"LTX POOL MANAGER: Creating workers for devices: {device_ids}")
-         self.workers = [LtxWorker(dev_id, ltx_config_file) for dev_id in device_ids]
          self.current_worker_index = 0
          self.lock = threading.Lock()
@@ -202,10 +72,45 @@ class LtxPoolManager:
          else:
              logger.info("LTX POOL MANAGER: Operating in CPU or mixed mode. GPU pre-warming skipped.")

      def _apply_ltx_pipeline_patches(self):
-         """
-         Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility.
-         """
          logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
          for worker in self.workers:
              worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)
@@ -217,7 +122,7 @@ class LtxPoolManager:
          self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
          return worker

-     def _prepare_pipeline_params(self, worker: LtxWorker, **kwargs) -> dict:
          pipeline_params = {
              "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
              "frame_rate": kwargs.get('video_fps', 24),
@@ -297,12 +202,134 @@ class LtxPoolManager:
              with torch.cuda.device(worker_to_use.device):
                  gc.collect(); torch.cuda.empty_cache()

  # --- Singleton Instantiation ---
- logger.info("Reading config.yaml to initialize LTX Pool Manager...")
  with open("config.yaml", 'r') as f:
      config = yaml.safe_load(f)
  ltx_gpus_required = config['specialists']['ltx']['gpus_required']
  ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
- ltx_config_path = config['specialists']['ltx']['config_file']
- ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file=ltx_config_path)
  logger.info("Video Specialist (LTX) ready.")
 
  #
  # Copyright (C) August 4, 2025 Carlos Rodrigues dos Santos
  #
+ # Version: 2.2.0
  #
+ # This file manages the LTX-Video specialist pool. It has been refactored to be
+ # self-contained by automatically cloning its own dependencies from the official
+ # LTX-Video repository. This modular approach makes the ADUC-SDR framework
+ # robust, portable, and easy to maintain.

  import torch
  import gc
  import os
+ import sys
  import yaml
  import logging
  import huggingface_hub
  import time
  import threading
+ import subprocess
+ from pathlib import Path
  from typing import Optional, List, Tuple, Union

+ from optimization import optimize_ltx_worker, can_optimize_fp8
+ from hardware_manager import hardware_manager

  logger = logging.getLogger(__name__)

+ # --- Dependency Management ---
+ DEPS_DIR = Path("./deps")
+ LTX_VIDEO_REPO_DIR = DEPS_DIR / "LTX-Video"
+ LTX_VIDEO_REPO_URL = "https://github.com/Lightricks/LTX-Video.git"

+ # --- Placeholders for lazy-loaded modules ---
+ create_ltx_video_pipeline = None
+ calculate_padding = None
+ LTXVideoPipeline = None
+ ConditioningItem = None
+ LatentConditioningItem = None
+ LTXMultiScalePipeline = None
+ vae_encode = None
+ latent_to_pixel_coords = None
+ randn_tensor = None


  class LtxPoolManager:
      """
      Manages a pool of LtxWorkers for optimized multi-GPU usage.
+     Handles its own code dependencies by cloning the LTX-Video repository.
      """
+     def __init__(self, device_ids, ltx_config_file_name):
          logger.info(f"LTX POOL MANAGER: Creating workers for devices: {device_ids}")
+         self._ltx_modules_loaded = False
+         self._setup_dependencies()
+         self._lazy_load_ltx_modules()
+
+         # Adjust config path to be inside the cloned repo
+         self.ltx_config_file = LTX_VIDEO_REPO_DIR / "configs" / ltx_config_file_name
+
+         self.workers = [LtxWorker(dev_id, self.ltx_config_file) for dev_id in device_ids]
          self.current_worker_index = 0
          self.lock = threading.Lock()
          else:
              logger.info("LTX POOL MANAGER: Operating in CPU or mixed mode. GPU pre-warming skipped.")

+     def _setup_dependencies(self):
+         """Clones the LTX-Video repo if not found and adds it to the system path."""
+         if not LTX_VIDEO_REPO_DIR.exists():
+             logger.info(f"LTX-Video repository not found at '{LTX_VIDEO_REPO_DIR}'. Cloning from GitHub...")
+             try:
+                 DEPS_DIR.mkdir(exist_ok=True)
+                 subprocess.run(
+                     ["git", "clone", LTX_VIDEO_REPO_URL, str(LTX_VIDEO_REPO_DIR)],
+                     check=True, capture_output=True, text=True
+                 )
+                 logger.info("LTX-Video repository cloned successfully.")
+             except subprocess.CalledProcessError as e:
+                 logger.error(f"Failed to clone LTX-Video repository. Git stderr: {e.stderr}")
+                 raise RuntimeError("Could not clone the required LTX-Video dependency from GitHub.")
+         else:
+             logger.info("Found local LTX-Video repository.")
+
+         if str(LTX_VIDEO_REPO_DIR.resolve()) not in sys.path:
+             sys.path.insert(0, str(LTX_VIDEO_REPO_DIR.resolve()))
+             logger.info(f"Added '{LTX_VIDEO_REPO_DIR.resolve()}' to sys.path.")
+
+     def _lazy_load_ltx_modules(self):
+         """Dynamically imports LTX-Video modules after ensuring the repo exists."""
+         if self._ltx_modules_loaded:
+             return
+
+         global create_ltx_video_pipeline, calculate_padding, LTXVideoPipeline, ConditioningItem, LatentConditioningItem
+         global vae_encode, latent_to_pixel_coords, LTXMultiScalePipeline, randn_tensor
+
+         from inference import create_ltx_video_pipeline, calculate_padding
+         from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline, ConditioningItem, LatentConditioningItem, LTXMultiScalePipeline
+         from ltx_video.models.autoencoders.vae_encode import vae_encode, latent_to_pixel_coords
+         from diffusers.utils.torch_utils import randn_tensor
+
+         self._ltx_modules_loaded = True
+         logger.info("LTX-Video modules have been dynamically loaded.")
+
      def _apply_ltx_pipeline_patches(self):
+         """Applies runtime patches to the LTX pipeline for ADUC-SDR compatibility."""
          logger.info("LTX POOL MANAGER: Applying ADUC-SDR patches to LTX pipeline...")
          for worker in self.workers:
              worker.pipeline.prepare_conditioning = _aduc_prepare_conditioning_patch.__get__(worker.pipeline, LTXVideoPipeline)

          self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
          return worker
+     def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
          pipeline_params = {
              "height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
              "frame_rate": kwargs.get('video_fps', 24),

              with torch.cuda.device(worker_to_use.device):
                  gc.collect(); torch.cuda.empty_cache()
+ class LtxWorker:
+     """
+     Represents a single instance of the LTX-Video pipeline on a specific device.
+     """
+     def __init__(self, device_id, ltx_config_file):
+         self.cpu_device = torch.device('cpu')
+         self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
+         logger.info(f"LTX Worker ({self.device}): Initializing with config '{ltx_config_file}'...")
+
+         with open(ltx_config_file, "r") as file:
+             self.config = yaml.safe_load(file)
+
+         self.is_distilled = "distilled" in self.config.get("checkpoint_path", "")
+
+         models_dir = LTX_VIDEO_REPO_DIR / "models_downloaded"
+
+         logger.info(f"LTX Worker ({self.device}): Preparing to load model...")
+         model_filename = self.config["checkpoint_path"]
+         model_path = huggingface_hub.hf_hub_download(
+             repo_id="Lightricks/LTX-Video", filename=model_filename,
+             local_dir=str(models_dir), local_dir_use_symlinks=False
+         )
+
+         self.pipeline = create_ltx_video_pipeline(
+             ckpt_path=model_path,
+             precision=self.config["precision"],
+             text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
+             sampler=self.config["sampler"],
+             device='cpu'
+         )
+         logger.info(f"LTX Worker ({self.device}): Model ready on CPU. Is distilled model? {self.is_distilled}")
+
+     def to_gpu(self):
+         if self.device.type == 'cpu': return
+         logger.info(f"LTX Worker: Moving pipeline to GPU {self.device}...")
+         self.pipeline.to(self.device)
+         if self.device.type == 'cuda' and can_optimize_fp8():
+             logger.info(f"LTX Worker ({self.device}): FP8 supported GPU detected. Optimizing...")
+             optimize_ltx_worker(self)
+             logger.info(f"LTX Worker ({self.device}): Optimization complete.")
+         elif self.device.type == 'cuda':
+             logger.info(f"LTX Worker ({self.device}): FP8 optimization not supported or disabled.")
+
+     def to_cpu(self):
+         if self.device.type == 'cpu': return
+         logger.info(f"LTX Worker: Unloading pipeline from GPU {self.device}...")
+         self.pipeline.to('cpu')
+         gc.collect()
+         if torch.cuda.is_available(): torch.cuda.empty_cache()
+
+     def generate_video_fragment_internal(self, **kwargs):
+         return self.pipeline(**kwargs).images
+
+
+ def _aduc_prepare_conditioning_patch(
+     self: LTXVideoPipeline,
+     conditioning_items: Optional[List[Union[ConditioningItem, "LatentConditioningItem"]]],
+     init_latents: torch.Tensor,
+     num_frames: int,
+     height: int,
+     width: int,
+     vae_per_channel_normalize: bool = False,
+     generator=None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
+     if not conditioning_items:
+         init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+         init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+         return init_latents, init_pixel_coords, None, 0
+     init_conditioning_mask = torch.zeros(init_latents[:, 0, :, :, :].shape, dtype=torch.float32, device=init_latents.device)
+     extra_conditioning_latents = []
+     extra_conditioning_pixel_coords = []
+     extra_conditioning_mask = []
+     extra_conditioning_num_latents = 0
+     is_latent_mode = hasattr(conditioning_items[0], 'latent_tensor')
+     if is_latent_mode:
+         for item in conditioning_items:
+             media_item_latents = item.latent_tensor.to(dtype=init_latents.dtype, device=init_latents.device)
+             media_frame_number = item.media_frame_number
+             strength = item.conditioning_strength
+             if media_frame_number == 0:
+                 f_l, h_l, w_l = media_item_latents.shape[-3:]
+                 init_latents[:, :, :f_l, :h_l, :w_l] = torch.lerp(init_latents[:, :, :f_l, :h_l, :w_l], media_item_latents, strength)
+                 init_conditioning_mask[:, :f_l, :h_l, :w_l] = strength
+             else:
+                 noise = randn_tensor(media_item_latents.shape, generator=generator, device=media_item_latents.device, dtype=media_item_latents.dtype)
+                 media_item_latents = torch.lerp(noise, media_item_latents, strength)
+                 patched_latents, latent_coords = self.patchifier.patchify(latents=media_item_latents)
+                 pixel_coords = latent_to_pixel_coords(latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+                 pixel_coords[:, 0] += media_frame_number
+                 extra_conditioning_num_latents += patched_latents.shape[1]
+                 new_mask = torch.full(patched_latents.shape[:2], strength, dtype=torch.float32, device=init_latents.device)
+                 extra_conditioning_latents.append(patched_latents)
+                 extra_conditioning_pixel_coords.append(pixel_coords)
+                 extra_conditioning_mask.append(new_mask)
+     else:
+         for item in conditioning_items:
+             if not isinstance(item, ConditioningItem): continue
+             item = self._resize_conditioning_item(item, height, width)
+             media_item_latents = vae_encode(item.media_item.to(dtype=self.vae.dtype, device=self.vae.device), self.vae, vae_per_channel_normalize=vae_per_channel_normalize).to(dtype=init_latents.dtype)
+             media_frame_number = item.media_frame_number
+             strength = item.conditioning_strength
+             if media_frame_number == 0:
+                 media_item_latents, l_x, l_y = self._get_latent_spatial_position(media_item_latents, item, height, width, strip_latent_border=True)
+                 f_l, h_l, w_l = media_item_latents.shape[-3:]
+                 init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = torch.lerp(init_latents[:, :, :f_l, l_y:l_y+h_l, l_x:l_x+w_l], media_item_latents, strength)
+                 init_conditioning_mask[:, :f_l, l_y:l_y+h_l, l_x:l_x+w_l] = strength
+             else:
+                 logger.warning("Pixel-based conditioning for non-zero frames is not fully implemented in this patch.")
+     init_latents, init_latent_coords = self.patchifier.patchify(latents=init_latents)
+     init_pixel_coords = latent_to_pixel_coords(init_latent_coords, self.vae, causal_fix=self.transformer.config.causal_temporal_positioning)
+     init_conditioning_mask, _ = self.patchifier.patchify(latents=init_conditioning_mask.unsqueeze(1))
+     init_conditioning_mask = init_conditioning_mask.squeeze(-1)
+     if extra_conditioning_latents:
+         init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
+         init_pixel_coords = torch.cat([*extra_conditioning_pixel_coords, init_pixel_coords], dim=2)
+         init_conditioning_mask = torch.cat([*extra_conditioning_mask, init_conditioning_mask], dim=1)
+         if self.transformer.use_tpu_flash_attention:
+             init_latents = init_latents[:, :-extra_conditioning_num_latents]
+             init_pixel_coords = init_pixel_coords[:, :, :-extra_conditioning_num_latents]
+             init_conditioning_mask = init_conditioning_mask[:, :-extra_conditioning_num_latents]
+     return init_latents, init_pixel_coords, init_conditioning_mask, extra_conditioning_num_latents
+
+
  # --- Singleton Instantiation ---
  with open("config.yaml", 'r') as f:
      config = yaml.safe_load(f)
  ltx_gpus_required = config['specialists']['ltx']['gpus_required']
  ltx_device_ids = hardware_manager.allocate_gpus('LTX', ltx_gpus_required)
+ ltx_config_filename = config['specialists']['ltx']['config_file']
+ ltx_manager_singleton = LtxPoolManager(device_ids=ltx_device_ids, ltx_config_file_name=ltx_config_filename)
  logger.info("Video Specialist (LTX) ready.")