Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import tensorflow as tf | |
| import pickle | |
| import numpy as np | |
| from pathlib import Path | |
| import dnnlib | |
| from dnnlib import tflib | |
| import imageio | |
| import os | |
| import subprocess | |
| import random | |
def check_gpu():
    """Report whether TensorFlow can see a GPU device.

    Thin wrapper around ``tf.test.is_gpu_available`` with CUDA-only
    filtering disabled and no minimum compute-capability requirement.
    """
    gpu_found = tf.test.is_gpu_available(
        cuda_only=False,
        min_cuda_compute_capability=None,
    )
    return gpu_found
def generate_image_from_projected_latents(latent_vector):
    """Render a batch of images from latent codes with the synthesis network.

    The codes passed in by this script come from ``Gs.components.mapping.run``
    (optionally shifted along latent directions). Relies on the module-level
    ``Gs`` network and ``Gs_kwargs`` run options, both assigned in the
    ``__main__`` section; ``Gs_kwargs`` requests uint8 NHWC output.
    """
    synthesis = Gs.components.synthesis
    return synthesis.run(latent_vector, **Gs_kwargs)
| ## define video generation methods | |
def ED_to_ES(latent_code):
    """Render a ping-pong clip along the ``"time"`` latent direction.

    The code is offset by 0, 1/25, ..., 24/25 times ``latent_controls["time"]``
    and then back by 24/25, ..., 0, giving 50 frames in total. Only the first
    image of each synthesis batch is kept.

    Returns:
        numpy array with the 50 rendered frames stacked along axis 0.
    """
    direction = latent_controls["time"]
    forward = [step / 25 for step in range(25)]
    backward = [1 - step / 25 for step in range(1, 26)]

    frames = []
    for amount in forward + backward:
        batch = generate_image_from_projected_latents(latent_code + direction * amount)
        frames.append(np.array(batch[0]))
    return np.array(frames)
def frame_to_frame(latent_code):
    """Render a 50-frame clip by chaining per-frame latent offsets.

    Starting from a copy of ``latent_code``, the offset stored in
    ``latent_controls`` under key ``f"{i}{i+1}"`` is added cumulatively for
    ``i = 0 .. 48``. The starting code and each of the 49 accumulated codes
    are rendered.

    Returns:
        numpy array of the stacked synthesis outputs with every singleton
        axis squeezed out (note: this also drops a size-1 channel axis, if any).
    """
    current = np.copy(latent_code)
    clips = [generate_image_from_projected_latents(current)]
    for step in range(49):
        current = current + latent_controls[f'{step}{step+1}']
        clips.append(generate_image_from_projected_latents(current))
    return np.array(clips).squeeze()
# NOTE: nothing here is cached. Streamlit re-executes the whole script on every
# widget change, so the checkpoint and latent directions are reloaded each rerun.
def load_initial_setup(seed=None):
    """Initialise TensorFlow, load the generator and the latent directions.

    Reads ``best_net.pkl`` and every ``*.npy`` file in ``trajectories/``
    (both relative to the current working directory). It then fills the
    synthesis network's ``noise*`` variables with standard-normal values.

    Args:
        seed: Optional seed for the RNG that draws the noise values. ``None``
            (the default) draws fresh, unseeded noise on every call, which
            matches the previous behaviour.

    Returns:
        ``(Gs, Gs_kwargs, latent_controls, sess)``:
          - the generator network from the pickle;
          - run options for ``Gs.components.synthesis.run`` (uint8, NHWC
            output, ``randomize_noise=False``);
          - a dict mapping each ``.npy`` file name, minus its extension, to
            the loaded array;
          - the default TensorFlow session after ``tflib.init_tf()``.
    """
    # Open the checkpoint first, so a missing file fails before TF is
    # initialised. The ``with`` guarantees the handle is closed even if
    # init_tf or unpickling raises.
    with open('best_net.pkl', 'rb') as stream:
        tflib.init_tf()
        sess = tf.get_default_session()
        # SECURITY: pickle executes arbitrary code on load; only ever point
        # this at a trusted checkpoint. G and D are not used by the demo.
        _G, _D, Gs = pickle.load(stream, encoding='latin1')

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # Presumably makes synthesis use the noise variables set below instead of
    # drawing new noise per run -- confirm against dnnlib's Network.run.
    Gs_kwargs.randomize_noise = False

    # Load the latent directions, one array per trajectories/<name>.npy file.
    # Skip directories whose names happen to end in ".npy".
    files = [x for x in Path('trajectories/').iterdir()
             if x.is_file() and str(x).endswith('.npy')]
    latent_controls = {f.name[:-4]: np.load(f) for f in files}

    # Randomise the synthesis network's noise inputs (computed once).
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items()
                  if name.startswith('noise')]
    rnd = np.random.RandomState(seed)
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})
    return Gs, Gs_kwargs, latent_controls, sess
def _encode_video(frames, workdir, idx):
    """Encode ``frames`` to a browser-playable H.264 MP4 and return its bytes.

    Frames are first written with imageio at 20 fps, then re-encoded with
    ffmpeg/libx264. Both files live in ``workdir``, so the caller controls
    their lifetime.

    Args:
        frames: array of frames, iterated along axis 0.
        workdir: directory that holds the intermediate and final files.
        idx: index used to keep the file names of different videos distinct.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status.
        FileNotFoundError: the ``ffmpeg`` executable is not on PATH.
    """
    raw_path = os.path.join(workdir, f"output{idx}.mp4")
    out_path = os.path.join(workdir, f"fixed_out{idx}.mp4")
    # The context manager closes the writer even if append_data raises.
    with imageio.get_writer(raw_path, fps=20) as writer:
        for frame in frames:
            writer.append_data(frame)
    # -y: never block on an "overwrite?" prompt; stdin is closed for the same reason.
    # yuv420p: the pixel format most browsers can decode for H.264.
    command = ["ffmpeg", "-y", "-loglevel", "error", "-i", raw_path,
               "-vcodec", "libx264", "-pix_fmt", "yuv420p", out_path]
    subprocess.run(command, check=True, stdin=subprocess.DEVNULL)
    with open(out_path, "rb") as f:
        return f.read()


if __name__ == "__main__":
    import tempfile  # only needed by the script entry point

    # Absolute path of the directory that contains this script.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    heart_image_path = os.path.join(dir_path, 'heart.png')  # currently unused

    st.markdown("""
    <style>
    .logo-test{
    font-weight:700 !important;
    font-size:50px !important;
    color:#FF0000 !important;
    text-align: center;
    }
    </style>
    """, unsafe_allow_html=True)
    st.markdown('<p class="logo-test">GANcMRI</p>', unsafe_allow_html=True)

    # Description sliders
    st.markdown("""
    This demo showcases GANcMRI: Synthetic cardiac MRI generation. Upon starting the demo or refreshing the page, a unique video will be automatically generated based on two methods described in our paper: ED-to-ES and Frame-to-Frame. These methods simulate cardiac function and movement in a realistic manner. The demo includes interactive sliders that allow you to adjust two parameters:
    1. **Sphericity Index:** This slider controls the sphericity of the left ventricle in the generated video. [The sphericity index](https://www.cell.com/med/pdf/S2666-6340(23)00069-7.pdf) is a measure of how spherical (round) the left ventricle appears, which is an important aspect in assessing certain heart conditions.
    2. **Left Ventricular Volume:** With this slider, you can modify the size of the left ventricle.
    """)
    sphericity_index = st.slider("Sphericity Index", -2., 3., 0.0)
    lv_area = st.slider("Left Ventricular Volume", -2., 3., 0.0)

    # One seed per browser session, so slider changes keep the same base heart.
    if 'random_number' not in st.session_state:
        st.session_state.random_number = random.randint(0, 1000000)

    cols = st.columns(2)
    with cols[0]:
        st.caption('ED-to-ES')
    with cols[1]:
        st.caption('Frame-to-Frame')

    rnd = np.random.RandomState(st.session_state.random_number)
    Gs, Gs_kwargs, latent_controls, sess = load_initial_setup()
    with sess.as_default():
        z = rnd.randn(1, *Gs.input_shape[1:])
        random_img_latent_code = Gs.components.mapping.run(z, None)
        # Move the code along the 'time' direction so it starts at the ED frame.
        random_img_latent_code -= 0.7 * latent_controls['time']
        # Apply the physiological adjustments chosen with the sliders.
        adjusted_latent_code = np.copy(random_img_latent_code)
        adjusted_latent_code += sphericity_index * latent_controls['sphericity_index']
        adjusted_latent_code += lv_area * latent_controls['lv_area']
        ed_to_es_vid = ED_to_ES(adjusted_latent_code)
        f_to_f_vid = frame_to_frame(adjusted_latent_code)

    # A private temporary directory per rerun: concurrent sessions on the same
    # server can no longer overwrite each other's files, and everything is
    # removed even if encoding fails.
    with tempfile.TemporaryDirectory() as workdir:
        for idx, vid in enumerate([ed_to_es_vid, f_to_f_vid]):
            try:
                video_bytes = _encode_video(vid, workdir, idx)
            except (subprocess.CalledProcessError, OSError) as err:
                with cols[idx]:
                    st.error(f"Video encoding failed: {err}")
                continue
            with cols[idx]:
                st.video(video_bytes)