import os
import torch
import spaces
import gradio as gr
import numpy as np
from PIL import Image
import ml_collections
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import einops
import random
import torchvision.transforms as standard_transforms
from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="thu-ml/unidiffuser-v1", filename="autoencoder_kl.pth", local_dir='./models')
hf_hub_download(repo_id="mespinosami/COP-GEN-Beta", filename="nnet_ema_114000.pth", local_dir='./models')

import sys
sys.path.append('./src/COP-GEN-Beta')
import libs.autoencoder
from dpm_solver_pp import DPM_Solver, NoiseScheduleVP
from sample_n_triffuser import set_seed, stable_diffusion_beta_schedule, unpreprocess
import utils
from diffusers import AutoencoderKL
from Triffuser import *  # resolvable as a top-level import via the sys.path entry added above

# Function to load model
def load_model(device='cuda'):
    nnet = Triffuser(num_modalities=4)
    checkpoint = torch.load('models/nnet_ema_114000.pth', map_location=device)
    nnet.load_state_dict(checkpoint)
    nnet.to(device)
    nnet.eval()

    autoencoder = libs.autoencoder.get_model(pretrained_path="models/autoencoder_kl.pth")
    autoencoder.to(device)
    autoencoder.eval()

    return nnet, autoencoder

print('Loading COP-GEN-Beta model...')
nnet, autoencoder = load_model()
to_PIL = standard_transforms.ToPILImage()
print('[DONE]')

def get_config(generate_modalities, condition_modalities, seed, num_inference_steps=50):
    config = ml_collections.ConfigDict()
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    config.seed = seed
    config.n_samples = 1
    config.z_shape = (4, 32, 32)  # Shape of the latent vectors

    config.sample = {
        'sample_steps': num_inference_steps,
        'algorithm': "dpm_solver",
    }

    # Model config
    config.num_modalities = 4  # 4 modalities: DEM, S1RTC, S2L1C, S2L2A
    config.modalities = ['dem', 's1_rtc', 's2_l1c', 's2_l2a']

    # Network config
    config.nnet = {
        'name': 'triffuser_multi_post_ln',
        'img_size': 32,
        'in_chans': 4,
        'patch_size': 2,
        'embed_dim': 1024,
        'depth': 20,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': False,
        'pos_drop_rate': 0.,
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'mlp_time_embed': False,
        'num_modalities': 4,
        'use_checkpoint': True,
    }

    # Parse generate and condition modalities
    config.generate_modalities = generate_modalities
    config.generate_modalities = sorted(config.generate_modalities, key=lambda x: config.modalities.index(x))
    config.condition_modalities = condition_modalities if condition_modalities else []
    config.condition_modalities = sorted(config.condition_modalities, key=lambda x: config.modalities.index(x))
    config.generate_modalities_mask = [mod in config.generate_modalities for mod in config.modalities]
    config.condition_modalities_mask = [mod in config.condition_modalities for mod in config.modalities]

    # Validate modalities
    valid_modalities = {'s2_l1c', 's2_l2a', 's1_rtc', 'dem'}
    for mod in config.generate_modalities + config.condition_modalities:
        if mod not in valid_modalities:
            raise ValueError(f"Invalid modality: {mod}. Must be one of {valid_modalities}")

    # Check that generate and condition modalities don't overlap
    if set(config.generate_modalities) & set(config.condition_modalities):
        raise ValueError("Generate and condition modalities must be different")

    # Default checkpoint path
    config.nnet_path = 'models/nnet_ema_114000.pth'
    #config.autoencoder = {"pretrained_path": "assets/stable-diffusion/autoencoder_kl_ema.pth"}

    return config
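
# Illustrative example (not executed): a config for generating both Sentinel-2 products
# from a Sentinel-1 RTC image. Modality order follows config.modalities above
# ('dem', 's1_rtc', 's2_l1c', 's2_l2a').
#   cfg = get_config(generate_modalities=['s2_l1c', 's2_l2a'],
#                    condition_modalities=['s1_rtc'],
#                    seed=42, num_inference_steps=50)
#   cfg.generate_modalities_mask   # [False, False, True, True]
#   cfg.condition_modalities_mask  # [False, True, False, False]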

# Function to prepare images for inference
def prepare_images(images):
    transforms = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=(0.5,), std=(0.5,))
    ])
    img_tensors = []
    for img in images:
        img_tensors.append(transforms(img))  # Normalize to [-1, 1]; the batch dimension is added later in run_inference
    return img_tensors
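
# Shape sanity check (illustrative; assumes 256x256 RGB inputs, which matches the
# 4x32x32 latent grid of the 8x-downsampling KL autoencoder used here):
#   t = prepare_images([Image.new('RGB', (256, 256))])[0]
#   t.shape   # torch.Size([3, 256, 256]), values in [-1, 1]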

def run_inference(config, nnet, autoencoder, img_tensors):
    set_seed(config.seed)
    img_tensors = [tensor.to(config.device) for tensor in img_tensors]

    # Create a context tensor for all modalities
    img_contexts = torch.randn(config.num_modalities, 1, 2 * config.z_shape[0],
                               config.z_shape[1], config.z_shape[2], device=config.device)

    with torch.no_grad():
        # Encode the input images with the autoencoder
        z_conds = [autoencoder.encode_moments(tensor.unsqueeze(0)) for tensor in img_tensors]

        # Create a mapping of conditional modality indices to the encoded inputs
        cond_indices = [i for i, is_cond in enumerate(config.condition_modalities_mask) if is_cond]

        # Check that we have the right number of inputs
        if len(cond_indices) != len(z_conds):
            raise ValueError(f"Number of conditioning modalities ({len(cond_indices)}) must match number of input images ({len(z_conds)})")

        # Assign each encoded input to the corresponding modality
        for i, z_cond in zip(cond_indices, z_conds):
            img_contexts[i] = z_cond

        # Sample latents from the encoded distributions (mean and variance)
        z_imgs = torch.stack([autoencoder.sample(img_context) for img_context in img_contexts])

    # Generate initial noise for the modalities being generated
    _z_init = torch.randn(len(config.generate_modalities), 1, *z_imgs[0].shape[1:], device=config.device)

    def combine_joint(z_list):
        """Combine individual modality tensors into a single concatenated tensor"""
        return torch.concat([einops.rearrange(z_i, 'B C H W -> B (C H W)') for z_i in z_list], dim=-1)

    def split_joint(x, z_imgs, config):
        """
        Split the combined tensor back into individual modality tensors
        and arrange them according to the full set of modalities
        """
        C, H, W = config.z_shape
        z_dim = C * H * W
        z_generated = x.split([z_dim] * len(config.generate_modalities), dim=1)
        z_generated = {modality: einops.rearrange(z_i, 'B (C H W) -> B C H W', C=C, H=H, W=W)
                       for z_i, modality in zip(z_generated, config.generate_modalities)}
        z = []
        for i, modality in enumerate(config.modalities):
            if modality in config.generate_modalities:      # Modalities that are being denoised
                z.append(z_generated[modality])
            elif modality in config.condition_modalities:   # Modalities that are being conditioned on
                z.append(z_imgs[i])
            else:                                            # Modalities that are ignored
                z.append(torch.randn(x.shape[0], C, H, W, device=config.device))
        return z

    _x_init = combine_joint(_z_init)  # Initial tensor for the modalities being generated

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    def model_fn(x, t_continuous):
        t = t_continuous * N
        # Create timesteps for each modality based on the generate mask
        timesteps = [t if mask else torch.zeros_like(t) for mask in config.generate_modalities_mask]
        # Split the input into a list of tensors for all modalities
        z = split_joint(x, z_imgs, config)
        # Call the network with the right format
        z_out = nnet(z, t_imgs=timesteps)
        # Select only the generated modalities for the denoising process
        z_out_generated = [z_out[i]
                           for i, modality in enumerate(config.modalities)
                           if modality in config.generate_modalities]
        # Combine the outputs back into a single tensor
        return combine_joint(z_out_generated)

    # Sample using the DPM-Solver with exact parameters from sample_n_triffuser.py
    noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=config.device).float())
    dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)

    # Generate samples
    with torch.no_grad():
        with torch.autocast(device_type=config.device):
            x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)

    # Split the result back into individual modality tensors
    _zs = split_joint(x, z_imgs, config)

    # Replace conditional modalities with the latents of the original images
    for i, mask in enumerate(config.condition_modalities_mask):
        if mask:
            _zs[i] = z_imgs[i]

    # Decode and unprocess the generated samples
    generated_samples = []
    for i, modality in enumerate(config.modalities):
        if modality in config.generate_modalities:
            sample = autoencoder.decode(_zs[i])  # Decode the latent representation
            sample = unpreprocess(sample)        # Unpreprocess to [0, 1] range
            generated_samples.append((modality, sample))

    return generated_samples
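
# Output contract (illustrative, under the same 256x256 RGB assumption as above): with
# generate_modalities=['s2_l1c', 's2_l2a'] the function returns
# [('s2_l1c', tensor of shape (1, 3, 256, 256)), ('s2_l2a', ...)] with values in [0, 1].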

def custom_inference(images, generate_modalities, condition_modalities, num_inference_steps, seed=None):
    """
    Run custom inference with user-specified parameters

    Args:
        images: List of conditioning PIL images, ordered to match condition_modalities
            (canonical modality order: dem, s1_rtc, s2_l1c, s2_l2a)
        generate_modalities: List of modalities to generate
        condition_modalities: List of modalities to condition on
        num_inference_steps: Number of DPM-Solver sampling steps
        seed: Random seed; a random one is drawn if None
    Returns:
        Dict mapping modality names to generated tensors
    """
    if seed is None:
        seed = random.randint(0, int(1e8))

    img_tensors = prepare_images(images)
    config = get_config(generate_modalities, condition_modalities, seed=seed)
    config.sample.sample_steps = num_inference_steps

    generated_samples = run_inference(config, nnet, autoencoder, img_tensors)
    results = {modality: tensor for modality, tensor in generated_samples}
    return results
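
# Illustrative usage (the file name is hypothetical): condition on a single S1 RTC
# thumbnail and generate the remaining three modalities.
#   s1 = Image.open('example_s1_rtc.png').convert('RGB').resize((256, 256))
#   out = custom_inference(images=[s1],
#                          generate_modalities=['dem', 's2_l1c', 's2_l2a'],
#                          condition_modalities=['s1_rtc'],
#                          num_inference_steps=50, seed=42)
#   sorted(out.keys())   # ['dem', 's2_l1c', 's2_l2a']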

def generate_output(s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_inference_steps_slider, seed_number, ignore_seed):
    seed = seed_number if not ignore_seed else None

    s2l2a_active = s2l2a_input is not None
    s2l1c_active = s2l1c_input is not None
    s1rtc_active = s1rtc_input is not None
    dem_active = dem_input is not None

    if s2l2a_active and s2l1c_active and s1rtc_active and dem_active:
        gr.Warning("All four modalities were provided, so there is nothing to generate. Remove the inputs you would like to have generated.")
        return s2l1c_input, s2l2a_input, s1rtc_input, dem_input

    # Instead of collecting in UI order, build a dictionary keyed by modality name
    input_images = {}
    if s2l1c_active:
        input_images['s2_l1c'] = s2l1c_input
    if s2l2a_active:
        input_images['s2_l2a'] = s2l2a_input
    if s1rtc_active:
        input_images['s1_rtc'] = s1rtc_input
    if dem_active:
        input_images['dem'] = dem_input

    condition_modalities = list(input_images.keys())

    # Sort modalities into canonical order and collect images in the same order
    sorted_modalities = sorted(condition_modalities, key=lambda x: ['dem', 's1_rtc', 's2_l1c', 's2_l2a'].index(x))
    sorted_images = [input_images[mod] for mod in sorted_modalities]

    imgs_out = custom_inference(
        images=sorted_images,
        generate_modalities=[el for el in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem'] if el not in condition_modalities],
        condition_modalities=sorted_modalities,
        num_inference_steps=num_inference_steps_slider,
        seed=seed
    )
    # Collect outputs in UI order, passing conditioning images through unchanged
    output = []
    for modality in ['s2_l1c', 's2_l2a', 's1_rtc', 'dem']:
        if modality in input_images:
            output.append(input_images[modality])
        else:
            output.append(to_PIL(imgs_out[modality][0]))
    return output
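
# Minimal wiring sketch (illustrative only; the component names and layout below are
# assumptions, not necessarily the Space's actual UI). generate_output expects and
# returns images in the fixed order s2_l1c, s2_l2a, s1_rtc, dem used above.
#   with gr.Blocks() as demo:
#       imgs = [gr.Image(type='pil', label=n) for n in ('S2 L1C', 'S2 L2A', 'S1 RTC', 'DEM')]
#       steps = gr.Slider(10, 100, value=50, step=1, label='Inference steps')
#       seed = gr.Number(value=42, label='Seed')
#       ignore = gr.Checkbox(value=True, label='Use random seed')
#       btn = gr.Button('Generate')
#       outs = [gr.Image(type='pil', label=f'Generated {n}') for n in ('S2 L1C', 'S2 L2A', 'S1 RTC', 'DEM')]
#       btn.click(generate_output, inputs=imgs + [steps, seed, ignore], outputs=outs)
#   demo.launch()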