Spaces:
Runtime error
Runtime error
| import torch | |
| from diffusers.loaders import AttnProcsLayers | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from modules.beats.BEATs import BEATs, BEATsConfig | |
| from modules.AudioToken.embedder import FGAEmbedder | |
| from diffusers import AutoencoderKL, UNet2DConditionModel | |
| from diffusers.models.attention_processor import LoRAAttnProcessor | |
| from diffusers import StableDiffusionPipeline | |
| import numpy as np | |
| import gradio as gr | |
| from scipy import signal | |
class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together.

    The constructor loads the Stable Diffusion v1-4 tokenizer, text encoder,
    UNet and VAE, a BEATs audio encoder, and the FGA embedder whose output is
    used as the embedding of the ``<*>`` placeholder token. All sub-models are
    put in eval mode.

    Args:
        lora: When truthy, install LoRA attention processors on the UNet and
            load their weights from ``models/lora_layers_learned_embeds.bin``.
        device: Target passed as ``map_location`` to every ``torch.load`` call.

    Raises:
        ValueError: If the tokenizer already contains the placeholder token, or
            if a UNet attention processor name belongs to no known block group.
    """

    def __init__(
        self,
        lora,
        device,
    ):
        super().__init__()

        sd_repo = "CompVis/stable-diffusion-v1-4"

        # Load scheduler and models
        self.tokenizer = CLIPTokenizer.from_pretrained(
            sd_repo, subfolder="tokenizer"
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            sd_repo, subfolder="text_encoder", revision=None
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            sd_repo, subfolder="unet", revision=None
        )
        self.vae = AutoencoderKL.from_pretrained(
            sd_repo, subfolder="vae", revision=None
        )

        # map_location keeps this load working on hosts without the device the
        # checkpoint was saved from (e.g. a CPU-only machine), and matches the
        # other torch.load calls below.
        checkpoint = torch.load(
            'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt',
            map_location=device,
        )
        cfg = BEATsConfig(checkpoint['cfg'])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint['model'])
        # Only extract_features() is used by this app, so the predictor is dropped.
        self.aud_encoder.predictor = None

        # NOTE(review): presumably three concatenated 768-d feature sets — confirm
        # against FGAEmbedder.
        input_size = 768 * 3
        self.embedder = FGAEmbedder(input_size=input_size, output_size=768)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if lora:
            # Set correct lora layers
            block_out_channels = self.unet.config.block_out_channels
            lora_attn_procs = {}
            for name in self.unet.attn_processors.keys():
                # Self-attention ("attn1") processors take no cross-attention input.
                cross_attention_dim = None if name.endswith(
                    "attn1.processor") else self.unet.config.cross_attention_dim
                if name.startswith("mid_block"):
                    hidden_size = block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(block_out_channels))[block_id]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = block_out_channels[block_id]
                else:
                    # Without this branch an unknown name would silently reuse the
                    # hidden size left over from the previous iteration.
                    raise ValueError(
                        f"Unexpected attention processor name: {name}"
                    )
                lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                )

            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
            self.lora_layers.eval()
            lora_layers_learned_embeds = 'models/lora_layers_learned_embeds.bin'
            self.lora_layers.load_state_dict(
                torch.load(lora_layers_learned_embeds, map_location=device)
            )
            self.unet.load_attn_procs(lora_layers_learned_embeds)

        self.embedder.eval()
        embedder_learned_embeds = 'models/embedder_learned_embeds.bin'
        self.embedder.load_state_dict(
            torch.load(embedder_learned_embeds, map_location=device)
        )

        self.placeholder_token = '<*>'
        num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(self.placeholder_token)

        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
# Lazily built Stable Diffusion pipeline, shared by every call to greet().
_PIPELINE = None


def _prepare_audio(sample_rate, audio, target_rate=16000):
    """Return ``audio`` as mono floating-point samples at ``target_rate`` Hz."""
    if audio.size == 0:
        raise ValueError("The provided audio clip is empty.")

    if np.issubdtype(audio.dtype, np.signedinteger):
        # Scale signed PCM by its full-scale value (32768 for int16).
        full_scale = float(2 ** (8 * audio.dtype.itemsize - 1))
        audio = audio.astype(np.float32, order='C') / full_scale
    elif np.issubdtype(audio.dtype, np.floating):
        # NOTE(review): float samples are assumed to already be in [-1, 1].
        audio = audio.astype(np.float32, order='C')
    else:
        audio = audio.astype(np.float32, order='C') / 32768.0

    if audio.ndim == 2:
        # Average the channels (axis 1) down to mono.
        audio = audio.mean(axis=1)

    if sample_rate != target_rate:
        resample_ratio = target_rate / sample_rate
        new_length = int(len(audio) * resample_ratio)
        audio = signal.resample(audio, new_length)

    return audio


def _get_pipeline():
    """Build the Stable Diffusion pipeline on first use and reuse it afterwards."""
    global _PIPELINE
    if _PIPELINE is None:
        # The pipeline holds references to the wrapper's own sub-models, so the
        # in-place placeholder-embedding update done by greet() is visible to it.
        _PIPELINE = StableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            tokenizer=model.tokenizer,
            text_encoder=model.text_encoder,
            vae=model.vae,
            unet=model.unet,
        ).to(device)
    return _PIPELINE


def greet(audio):
    """Generate an image conditioned on an audio clip.

    Reads the module-level ``model`` and ``device`` globals.

    Args:
        audio: A ``(sample_rate, samples)`` pair, where ``samples`` is a NumPy
            array that is either 1-D or 2-D with channels on axis 1.

    Returns:
        The first image produced by the Stable Diffusion pipeline for the
        prompt ``'a photo of <*>'``.

    Raises:
        ValueError: If ``audio`` is ``None`` or contains no samples.
    """
    if audio is None:
        raise ValueError("No audio was provided; please record or upload a clip.")

    sample_rate, samples = audio
    samples = _prepare_audio(sample_rate, samples)

    weight_dtype = torch.float32
    prompt = 'a photo of <*>'

    # Add the batch dimension expected by the audio encoder.
    audio_values = torch.unsqueeze(torch.tensor(samples), dim=0).to(device).to(dtype=weight_dtype)

    # Inference only: avoid building an autograd graph for the encoders.
    with torch.no_grad():
        aud_features = model.aud_encoder.extract_features(audio_values)[1]
        audio_token = model.embedder(aud_features)

        # Overwrite the placeholder token's embedding with the audio embedding.
        token_embeds = model.text_encoder.get_input_embeddings().weight.data
        token_embeds[model.placeholder_token_id] = audio_token.clone()

    pipeline = _get_pipeline()
    image = pipeline(prompt, num_inference_steps=40, guidance_scale=7.5).images[0]
    return image
if __name__ == "__main__":
    # Static UI content first; none of it depends on the loaded models.
    description = """<p>
This is a demo of <a href='https://pages.cs.huji.ac.il/adiyoss-lab/AudioToken' target='_blank'>AudioToken: Adaptation of Text-Conditioned Diffusion Models for Audio-to-Image Generation</a>.<br><br>
A novel method utilizing latent diffusion models trained for text-to-image-generation to generate images conditioned on audio recordings. Using a pre-trained audio encoding model, the proposed method encodes audio into a new token, which can be considered as an adaptation layer between the audio and text representations.<br><br>
For more information, please see the original <a href='https://arxiv.org/abs/2305.13050' target='_blank'>paper</a> and <a href='https://github.com/guyyariv/AudioToken' target='_blank'>repo</a>.
</p>"""

    # Only the dog-barking clip is enabled. Disabled examples:
    # train, airplane taking off, electric guitar, female sings.
    examples = [["assets/dog barking.wav"]]

    # `device` and `model` must remain module-level names: greet() reads them
    # as globals when a request comes in.
    lora = False
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AudioTokenWrapper(lora, device).to(device)

    interface_options = dict(
        fn=greet,
        inputs="audio",
        outputs="image",
        title='AudioToken',
        description=description,
        examples=examples,
    )
    demo = gr.Interface(**interface_options)
    demo.launch()