|
|
import os

import torch
from safetensors.torch import save_file, load_file
|
|
|
|
|
def _add_prefixed(merged, state, prefix):
    """Copy ``state`` into ``merged`` with every key namespaced under ``prefix``.

    Keys that already start with ``prefix`` are kept as-is; all others get
    ``prefix`` prepended. Existing entries in ``merged`` are overwritten
    (last writer wins).

    Returns:
        List of destination keys that already existed in ``merged`` and were
        therefore replaced.
    """
    overwritten = []
    for key, value in state.items():
        target = key if key.startswith(prefix) else f"{prefix}{key}"
        if target in merged:
            overwritten.append(target)
        merged[target] = value
    return overwritten


def merge_model_components(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
):
    """Merge UNet, VAE, and text encoder into a single safetensors file.

    The UNet entries are copied unchanged. VAE entries are stored under the
    ``vae.`` prefix and text encoder entries under the ``text_encoder.``
    prefix (the prefix is added only when a key does not already carry it).
    If a prefixed key collides with an entry that is already present, the
    later component wins and a warning is printed so the overwrite is not
    silent.

    Args:
        unet_path: Path to the main model/unet safetensors file.
        vae_path: Path to the VAE safetensors file.
        text_encoder_path: Path to the text encoder/CLIP safetensors file.
        output_path: Path where the merged file will be saved.

    Raises:
        OSError: If the size of the written file cannot be read. Errors from
            ``load_file`` / ``save_file`` propagate unchanged.
    """
    print("Loading UNet/Model weights...")
    unet_state = load_file(unet_path)

    print("Loading VAE weights...")
    vae_state = load_file(vae_path)

    print("Loading Text Encoder weights...")
    text_encoder_state = load_file(text_encoder_path)

    print("Merging state dictionaries...")
    # Start from a copy of the UNet entries so the loaded dict is not mutated.
    merged_state = dict(unet_state)

    namespaced_components = (
        ("VAE", vae_state, "vae."),
        ("Text Encoder", text_encoder_state, "text_encoder."),
    )
    for label, state, prefix in namespaced_components:
        overwritten = _add_prefixed(merged_state, state, prefix)
        if overwritten:
            # Behavior is unchanged (later component wins); this only makes
            # the replacement visible to the user.
            print(
                f"Warning: {label} replaced {len(overwritten)} existing "
                f"key(s) in the merged model (e.g. {overwritten[0]!r})"
            )

    # len() of the state dict counts tensor entries, not individual parameters.
    print(f"Total tensors in merged model: {len(merged_state)}")
    print(f"Saving merged model to {output_path}...")

    save_file(merged_state, output_path)

    print("✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: component locations are hard-coded and resolved
    # relative to the current working directory.
    component_paths = {
        "unet_path": "flux1-depth-dev.safetensors",
        "vae_path": "vae/diffusion_pytorch_model.safetensors",
        "text_encoder_path": "text_encoder/model.safetensors",
        "output_path": "flux1-depth-dev_merged_model.safetensors",
    }
    merge_model_components(**component_paths)
|
|
|