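"""
Conversion script for NVIDIA Cosmos (1.0 / 2.0) diffusion checkpoints to the diffusers format.

Converts the transformer into a `CosmosTransformer3DModel` and, optionally, the tokenizer/VAE
(`AutoencoderKLCosmos` for Cosmos 1.0, the Wan 2.1 VAE for Cosmos 2.0), and can assemble a full
`CosmosTextToWorldPipeline` / `CosmosTextToImagePipeline`. See the example invocation at the
bottom of the file.
"""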
import argparse
import pathlib
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast

from diffusers import (
    AutoencoderKLCosmos,
    AutoencoderKLWan,
    CosmosTextToImagePipeline,
    CosmosTextToWorldPipeline,
    CosmosTransformer3DModel,
    EDMEulerScheduler,
)
def remove_keys_(key: str, state_dict: Dict[str, Any]) -> None:
    state_dict.pop(key)


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]) -> None:
    # Map the original "blocks.blockN." prefix to diffusers' "transformer_blocks.N.".
    block_index = int(key.split(".")[1].removeprefix("block"))
    old_prefix = f"blocks.block{block_index}"
    new_prefix = f"transformer_blocks.{block_index}"
    new_key = new_prefix + key.removeprefix(old_prefix)
    state_dict[new_key] = state_dict.pop(key)
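# Illustrative example (hypothetical key) of how the two renaming passes below compose:
#   "net.blocks.block3.blocks.0.block.attn.to_q.0.weight"
#   --(prefix strip + rename dict)--> "blocks.block3.attn1.to_q.weight"
#   --(rename_transformer_blocks_)--> "transformer_blocks.3.attn1.to_q.weight"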
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
    "t_embedder.1": "time_embed.t_embedder",
    "affline_norm": "time_embed.norm",
    ".blocks.0.block.attn": ".attn1",
    ".blocks.1.block.attn": ".attn2",
    ".blocks.2.block": ".ff",
    ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
    ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
    ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
    ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
    ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
    ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
    "to_q.0": "to_q",
    "to_q.1": "norm_q",
    "to_k.0": "to_k",
    "to_k.1": "norm_k",
    "to_v.0": "to_v",
    "layer1": "net.0.proj",
    "layer2": "net.2",
    "proj.1": "proj",
    "x_embedder": "patch_embed",
    "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaLN_modulation.1": "norm_out.linear_1",
    "final_layer.adaLN_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
    "blocks.block": rename_transformer_blocks_,
    "logvar.0.freqs": remove_keys_,
    "logvar.0.phases": remove_keys_,
    "logvar.1.weight": remove_keys_,
    "pos_embedder.seq": remove_keys_,
}

TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
    "t_embedder.1": "time_embed.t_embedder",
    "t_embedding_norm": "time_embed.norm",
    "blocks": "transformer_blocks",
    "adaln_modulation_self_attn.1": "norm1.linear_1",
    "adaln_modulation_self_attn.2": "norm1.linear_2",
    "adaln_modulation_cross_attn.1": "norm2.linear_1",
    "adaln_modulation_cross_attn.2": "norm2.linear_2",
    "adaln_modulation_mlp.1": "norm3.linear_1",
    "adaln_modulation_mlp.2": "norm3.linear_2",
    "self_attn": "attn1",
    "cross_attn": "attn2",
    "q_proj": "to_q",
    "k_proj": "to_k",
    "v_proj": "to_v",
    "output_proj": "to_out.0",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
    "mlp.layer1": "ff.net.0.proj",
    "mlp.layer2": "ff.net.2",
    "x_embedder.proj.1": "patch_embed.proj",
    # "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaln_modulation.1": "norm_out.linear_1",
    "final_layer.adaln_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
    "accum_video_sample_counter": remove_keys_,
    "accum_image_sample_counter": remove_keys_,
    "accum_iteration": remove_keys_,
    "accum_train_in_hours": remove_keys_,
    "pos_embedder.seq": remove_keys_,
    "pos_embedder.dim_spatial_range": remove_keys_,
    "pos_embedder.dim_temporal_range": remove_keys_,
    "_extra_state": remove_keys_,
}
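# The keys above are training counters and recomputable positional-embedding buffers with no
# diffusers counterpart, so they are simply dropped during conversion.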
TRANSFORMER_CONFIGS = {
    "Cosmos-1.0-Diffusion-7B-Text2World": {
        "in_channels": 16,
        "out_channels": 16,
        "num_attention_heads": 32,
        "attention_head_dim": 128,
        "num_layers": 28,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (2.0, 1.0, 1.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": "learnable",
    },
    "Cosmos-1.0-Diffusion-7B-Video2World": {
        "in_channels": 16 + 1,
        "out_channels": 16,
        "num_attention_heads": 32,
        "attention_head_dim": 128,
        "num_layers": 28,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (2.0, 1.0, 1.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": "learnable",
    },
    "Cosmos-1.0-Diffusion-14B-Text2World": {
        "in_channels": 16,
        "out_channels": 16,
        "num_attention_heads": 40,
        "attention_head_dim": 128,
        "num_layers": 36,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (2.0, 2.0, 2.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": "learnable",
    },
    "Cosmos-1.0-Diffusion-14B-Video2World": {
        "in_channels": 16 + 1,
        "out_channels": 16,
        "num_attention_heads": 40,
        "attention_head_dim": 128,
        "num_layers": 36,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (2.0, 2.0, 2.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": "learnable",
    },
    "Cosmos-2.0-Diffusion-2B-Text2Image": {
        "in_channels": 16,
        "out_channels": 16,
        "num_attention_heads": 16,
        "attention_head_dim": 128,
        "num_layers": 28,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (1.0, 4.0, 4.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": None,
    },
    "Cosmos-2.0-Diffusion-14B-Text2Image": {
        "in_channels": 16,
        "out_channels": 16,
        "num_attention_heads": 40,
        "attention_head_dim": 128,
        "num_layers": 36,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (1.0, 4.0, 4.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": None,
    },
    "Cosmos-2.0-Diffusion-2B-Video2World": {
        "in_channels": 16 + 1,
        "out_channels": 16,
        "num_attention_heads": 16,
        "attention_head_dim": 128,
        "num_layers": 28,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (1.0, 3.0, 3.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": None,
    },
    "Cosmos-2.0-Diffusion-14B-Video2World": {
        "in_channels": 16 + 1,
        "out_channels": 16,
        "num_attention_heads": 40,
        "attention_head_dim": 128,
        "num_layers": 36,
        "mlp_ratio": 4.0,
        "text_embed_dim": 1024,
        "adaln_lora_dim": 256,
        "max_size": (128, 240, 240),
        "patch_size": (1, 2, 2),
        "rope_scale": (20 / 24, 2.0, 2.0),
        "concat_padding_mask": True,
        "extra_pos_embed_type": None,
    },
}
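# Each entry above mirrors the hyperparameters of the corresponding NVIDIA checkpoint;
# convert_transformer() below instantiates CosmosTransformer3DModel(**config) from it.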
VAE_KEYS_RENAME_DICT = {
    "down.0": "down_blocks.0",
    "down.1": "down_blocks.1",
    "down.2": "down_blocks.2",
    "up.0": "up_blocks.2",
    "up.1": "up_blocks.1",
    "up.2": "up_blocks.0",
    ".block.": ".resnets.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "mid.block_1": "mid_block.resnets.0",
    "mid.attn_1.0": "mid_block.attentions.0",
    "mid.attn_1.1": "mid_block.temp_attentions.0",
    "mid.block_2": "mid_block.resnets.1",
    ".q.conv3d": ".to_q",
    ".k.conv3d": ".to_k",
    ".v.conv3d": ".to_v",
    ".proj_out.conv3d": ".to_out.0",
    ".0.conv3d": ".conv_s",
    ".1.conv3d": ".conv_t",
    "conv1.conv3d": "conv1",
    "conv2.conv3d": "conv2",
    "conv3.conv3d": "conv3",
    "nin_shortcut.conv3d": "conv_shortcut",
    "quant_conv.conv3d": "quant_conv",
    "post_quant_conv.conv3d": "post_quant_conv",
}

VAE_SPECIAL_KEYS_REMAP = {
    "wavelets": remove_keys_,
    "_arange": remove_keys_,
    "patch_size_buffer": remove_keys_,
}
VAE_CONFIGS = {
    "CV8x8x8-0.1": {
        "name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
        "diffusers_config": {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 16,
            "encoder_block_out_channels": (128, 256, 512, 512),
            "decode_block_out_channels": (256, 512, 512, 512),
            "attention_resolutions": (32,),
            "resolution": 1024,
            "num_layers": 2,
            "patch_size": 4,
            "patch_type": "haar",
            "scaling_factor": 1.0,
            "spatial_compression_ratio": 8,
            "temporal_compression_ratio": 8,
            "latents_mean": None,
            "latents_std": None,
        },
    },
    "CV8x8x8-1.0": {
        "name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
        "diffusers_config": {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 16,
            "encoder_block_out_channels": (128, 256, 512, 512),
            "decode_block_out_channels": (256, 512, 512, 512),
            "attention_resolutions": (32,),
            "resolution": 1024,
            "num_layers": 2,
            "patch_size": 4,
            "patch_type": "haar",
            "scaling_factor": 1.0,
            "spatial_compression_ratio": 8,
            "temporal_compression_ratio": 8,
            "latents_mean": None,
            "latents_std": None,
        },
    },
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict
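# For example, a checkpoint saved as {"model": {"net.x_embedder.proj.1.weight": ...}}
# (a hypothetical container shape) unwraps to the inner parameter dict.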
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
    PREFIX_KEY = "net."
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))

    if "Cosmos-1.0" in transformer_type:
        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
    elif "Cosmos-2.0" in transformer_type:
        TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
        TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
    else:
        raise ValueError(f"Unsupported transformer type: {transformer_type}")

    with init_empty_weights():
        config = TRANSFORMER_CONFIGS[transformer_type]
        transformer = CosmosTransformer3DModel(**config)

    # First pass: strip the "net." prefix and apply the substring renames.
    for key in list(original_state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = new_key.removeprefix(PREFIX_KEY)
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    # Second pass: apply the special handlers (block renumbering, key removal).
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer
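# Minimal usage sketch (hypothetical local checkpoint path):
#   transformer = convert_transformer(
#       "Cosmos-1.0-Diffusion-7B-Text2World", "/path/to/model.pt", weights_only=True
#   )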
def convert_vae(vae_type: str):
    model_name = VAE_CONFIGS[vae_type]["name"]
    snapshot_directory = snapshot_download(model_name, repo_type="model")
    directory = pathlib.Path(snapshot_directory)

    autoencoder_file = directory / "autoencoder.jit"
    mean_std_file = directory / "mean_std.pt"

    original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
    if mean_std_file.exists():
        mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
    else:
        mean_std = (None, None)

    config = VAE_CONFIGS[vae_type]["diffusers_config"]
    config.update(
        {
            # Guard against tokenizer repos that ship without a mean_std.pt file.
            "latents_mean": mean_std[0].detach().cpu().numpy().tolist() if mean_std[0] is not None else None,
            "latents_std": mean_std[1].detach().cpu().numpy().tolist() if mean_std[1] is not None else None,
        }
    )
    vae = AutoencoderKLCosmos(**config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae
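# Usage sketch: downloads the tokenizer weights from the Hub and converts them in one call.
#   vae = convert_vae("CV8x8x8-1.0")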
def save_pipeline_cosmos_1_0(args, transformer, vae):
    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
    # The original code initializes the EDM config with sigma_min=0.0002 but never uses that
    # value directly, so the sigma_min that actually takes effect is the default of 0.002.
    scheduler = EDMEulerScheduler(
        sigma_min=0.002,
        sigma_max=80,
        sigma_data=0.5,
        sigma_schedule="karras",
        num_train_timesteps=1000,
        prediction_type="epsilon",
        rho=7.0,
        final_sigmas_type="sigma_min",
    )

    pipe = CosmosTextToWorldPipeline(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        vae=vae,
        scheduler=scheduler,
    )
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos_2_0(args, transformer, vae):
    text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
    scheduler = EDMEulerScheduler(
        sigma_min=0.002,
        sigma_max=80,
        sigma_data=1.0,
        sigma_schedule="karras",
        num_train_timesteps=1000,
        prediction_type="epsilon",
        rho=7.0,
        final_sigmas_type="sigma_min",
        use_flow_sigmas=True,
    )

    pipe = CosmosTextToImagePipeline(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        transformer=transformer,
        vae=vae,
        scheduler=scheduler,
    )
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument(
        "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
    )
    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument(
        "--dtype", default="bf16", choices=list(DTYPE_MAPPING.keys()), help="Torch dtype to save the transformer in."
    )
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
if __name__ == "__main__":
    args = get_args()

    transformer = None
    dtype = DTYPE_MAPPING[args.dtype]

    if args.save_pipeline:
        assert args.transformer_type is not None
        assert args.transformer_ckpt_path is not None
        assert args.vae_type is not None
        assert args.text_encoder_path is not None
        assert args.tokenizer_path is not None

    if args.transformer_ckpt_path is not None:
        weights_only = "Cosmos-1.0" in args.transformer_type
        transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
        transformer = transformer.to(dtype=dtype)
        if not args.save_pipeline:
            transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.vae_type is not None:
        # `--transformer_type` may be unset for VAE-only conversion; the standalone VAE
        # configs above are all Cosmos 1.0 tokenizers, so default to that path.
        if args.transformer_type is None or "Cosmos-1.0" in args.transformer_type:
            vae = convert_vae(args.vae_type)
        else:
            vae = AutoencoderKLWan.from_pretrained(
                "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
            )
        if not args.save_pipeline:
            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

    if args.save_pipeline:
        if "Cosmos-1.0" in args.transformer_type:
            save_pipeline_cosmos_1_0(args, transformer, vae)
        elif "Cosmos-2.0" in args.transformer_type:
            save_pipeline_cosmos_2_0(args, transformer, vae)
        else:
            raise ValueError(f"Unsupported transformer type: {args.transformer_type}")
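# Example invocation (hypothetical script name and paths; adjust to your environment):
#   python convert_cosmos_to_diffusers.py \
#       --transformer_type Cosmos-1.0-Diffusion-7B-Text2World \
#       --transformer_ckpt_path /path/to/Cosmos-1.0-Diffusion-7B-Text2World/model.pt \
#       --vae_type CV8x8x8-1.0 \
#       --save_pipeline \
#       --output_path ./cosmos-1.0-7b-text2world-diffusers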