import os
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path

import cv2
import gradio as gr
import hydra
import numpy as np
import spaces
import torch
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from PIL import Image

from algorithms.worldmem import WorldMemMinecraft
from datasets.video.minecraft_video_dataset import *
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
from utils.cluster_utils import submit_slurm_job
from utils.distributed_utils import is_rank_zero
from utils.print_utils import cyan

torch.set_float32_matmul_precision("high")

ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

# Mapping of input keys to (action name, value) pairs
KEY_TO_ACTION = {
    "Q": ("forward", 1),
    "E": ("back", 1),
    "W": ("cameraY", -1),
    "S": ("cameraY", 1),
    "A": ("cameraX", -1),
    "D": ("cameraX", 1),
    "U": ("drop", 1),
    "N": ("noop", 1),
    "1": ("hotbar.1", 1),
}

example_images = [
    ["1", "assets/ice_plains.png", "turn right+go backward+look up+turn left+look down+turn right+go forward+turn left", 20, 3, 8],
    ["2", "assets/place.png", "put item+go backward+put item+go backward+go around", 20, 3, 8],
    ["3", "assets/rain_sunflower_plains.png", "turn right+look up+turn right+look down+turn left+go backward+turn left", 20, 3, 8],
    ["4", "assets/desert.png", "turn 360 degree+turn right+go forward+turn left", 20, 3, 8],
]


def load_custom_checkpoint(algo, checkpoint_path):
    """Download a checkpoint from the Hugging Face Hub and load it into `algo`."""
    hf_ckpt = str(checkpoint_path).split('/')
    repo_id = '/'.join(hf_ckpt[:2])
    file_name = '/'.join(hf_ckpt[2:])
    model_path = hf_hub_download(repo_id=repo_id, filename=file_name)
    ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    algo.load_state_dict(ckpt['state_dict'], strict=False)


def parse_input_to_tensor(input_str):
    """
    Convert an input string into a (sequence_length, 25) tensor, where each row
    encodes the corresponding action key (camera actions use -1/+1).

    Args:
        input_str (str): A string of action characters from KEY_TO_ACTION (e.g., "WASDWS").

    Returns:
        torch.Tensor: A tensor of shape (sequence_length, 25).
    """
    seq_len = len(input_str)
    action_tensor = torch.zeros((seq_len, 25))

    # Iterate through the input string and update the corresponding positions.
    for i, char in enumerate(input_str):
        # Upper-case the key for case insensitivity; unknown keys map to no action.
        action, value = KEY_TO_ACTION.get(char.upper(), (None, 0))
        if action and action in ACTION_KEYS:
            index = ACTION_KEYS.index(action)
            action_tensor[i, index] = value  # Set the corresponding action index to the mapped value

    return action_tensor


def load_image_as_tensor(image_path) -> torch.Tensor:
    """
    Load an image and convert it to a tensor normalized to [0, 1].

    Args:
        image_path (str or PIL.Image.Image): Path to the image file, or an already-loaded image.

    Returns:
        torch.Tensor: Image tensor of shape (C, H, W), normalized to [0, 1].
    """
    if isinstance(image_path, str):
        image = Image.open(image_path).convert("RGB")  # Ensure it's RGB
    else:
        image = image_path
    transform = transforms.Compose([
        transforms.ToTensor(),  # Converts to tensor and normalizes to [0, 1]
    ])
    return transform(image)


def run_local(cfg: DictConfig):
    # Delay some imports in case they are not needed in non-local envs for submission.
    from experiments import build_experiment

    # Resolve the yaml names chosen by Hydra.
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # Launch the experiment.
    experiment = build_experiment(cfg, None, None)
    return experiment.exec_interactive(cfg.experiment.tasks[0])


def enable_amp(model, precision="16-mixed"):
    """Wrap `model.forward` in torch.autocast so inference runs in mixed precision."""
    original_forward = model.forward

    def amp_forward(*args, **kwargs):
        with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
            return original_forward(*args, **kwargs)

    model.forward = amp_forward
    return model


memory_frames = []
input_history = ""

ICE_PLAINS_IMAGE = "assets/ice_plains.png"
DESERT_IMAGE = "assets/desert.png"
SAVANNA_IMAGE = "assets/savanna.png"
PLAINS_IMAGE = "assets/plans.png"
PLACE_IMAGE = "assets/place.png"
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"

device = torch.device('cuda')


def save_video(frames, path="output.mp4", fps=10):
    """Write RGB frames to `path` as an H.264 MP4 and return the path."""
    h, w, _ = frames[0].shape
    # Write to an intermediate file first: ffmpeg cannot safely read and write the same file.
    raw_path = path + ".raw.mp4"
    out = cv2.VideoWriter(raw_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h))
    for frame in frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()

    # Re-encode with libx264 so the result is playable in browsers.
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-i", raw_path,
        "-c:v", "libx264",
        "-crf", "23",
        "-preset", "medium",
        path,
    ]
    subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    os.remove(raw_path)

    return path


cfg = OmegaConf.load("configurations/huggingface.yaml")
worldmem = WorldMemMinecraft(cfg)
load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffusion_path)
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
worldmem.to("cuda").eval()
worldmem = enable_amp(worldmem, precision="16-mixed")

actions = np.zeros((1, 25), dtype=np.float32)
poses = np.zeros((1, 5), dtype=np.float32)


def get_duration_single_image_to_long_video(first_frame, action, first_pose, device,
                                            self_frames, self_actions, self_poses,
                                            self_memory_c2w, self_frame_idx):
    # Allocate about 5 seconds of GPU time per requested frame (a flat 5 s for the bootstrap call).
    return 5 * len(action) if self_actions is not None else 5


@spaces.GPU(duration=get_duration_single_image_to_long_video)
def run_interactive(first_frame, action, first_pose, device,
                    self_frames, self_actions, self_poses,
                    self_memory_c2w, self_frame_idx):
    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(
        first_frame,
        action,
        first_pose,
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx,
    )
    return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx

def set_denoising_steps(denoising_steps, sampling_timesteps_state):
    worldmem.sampling_timesteps = denoising_steps
    worldmem.diffusion_model.sampling_timesteps = denoising_steps
    sampling_timesteps_state = denoising_steps
    print("set denoising steps to", worldmem.sampling_timesteps)
    return sampling_timesteps_state


def set_context_length(context_length, sampling_context_length_state):
    worldmem.n_tokens = context_length
    sampling_context_length_state = context_length
    print("set context length to", worldmem.n_tokens)
    return sampling_context_length_state


def set_memory_length(memory_length, sampling_memory_length_state):
    worldmem.condition_similar_length = memory_length
    sampling_memory_length_state = memory_length
    print("set memory length to", worldmem.condition_similar_length)
    return sampling_memory_length_state


def generate(keys, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
    input_actions = parse_input_to_tensor(keys)

    if self_frames is None:
        # First call of a session: bootstrap the caches from the initial frame with a no-op action.
        new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
            memory_frames[0],
            actions[0],
            poses[0],
            device=device,
            self_frames=self_frames,
            self_actions=self_actions,
            self_poses=self_poses,
            self_memory_c2w=self_memory_c2w,
            self_frame_idx=self_frame_idx,
        )

    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
        memory_frames[0],
        input_actions,
        None,
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx,
    )

    memory_frames = np.concatenate([memory_frames, new_frame[:, 0]])

    out_video = memory_frames.transpose(0, 2, 3, 1)
    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
    out_video = (out_video * 255).astype(np.uint8)

    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
    save_video(out_video, temporal_video_path)

    # Persist the full session so it can be replayed later as an example.
    now = datetime.now()
    folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
    folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
    os.makedirs(folder_path, exist_ok=True)

    input_history += keys

    data_dict = {
        "input_history": input_history,
        "memory_frames": memory_frames,
        "self_frames": self_frames,
        "self_actions": self_actions,
        "self_poses": self_poses,
        "self_memory_c2w": self_memory_c2w,
        "self_frame_idx": self_frame_idx,
    }
    np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)

    return out_video[-1], temporal_video_path, input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx


def reset(selected_image):
    self_frames = None
    self_poses = None
    self_actions = None
    self_memory_c2w = None
    self_frame_idx = None
    memory_frames = load_image_as_tensor(selected_image).numpy()[None]
    input_history = ""

    new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(
        memory_frames[0],
        actions[0],
        poses[0],
        device=device,
        self_frames=self_frames,
        self_actions=self_actions,
        self_poses=self_poses,
        self_memory_c2w=self_memory_c2w,
        self_frame_idx=self_frame_idx,
    )

    return input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx


def on_image_click(selected_image):
    input_history, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = reset(selected_image)
    return input_history, selected_image, selected_image, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx


def set_memory(examples_case, image_display, log_output, slider_denoising_step, slider_context_length, slider_memory_length):
    # Pre-computed rollouts for the gr.Examples entries, keyed by case id.
    bundle_paths = {
        '1': "/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-01-49/data_bundle.npz",
        '2': "/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-42-04/data_bundle.npz",
        '3': "/mnt/xiaozeqi/worldmem/output_material/2025-04-12_10-56-57/data_bundle.npz",
        '4': "/mnt/xiaozeqi/worldmem/output_material/2025-04-11_16-07-19/data_bundle.npz",
    }
    data_bundle = np.load(bundle_paths[examples_case])
    input_history = data_bundle['input_history'].item()
    memory_frames = data_bundle['memory_frames']
    self_frames = data_bundle['self_frames']
    self_actions = data_bundle['self_actions']
    self_poses = data_bundle['self_poses']
    self_memory_c2w = data_bundle['self_memory_c2w']
    self_frame_idx = data_bundle['self_frame_idx']

    out_video = memory_frames.transpose(0, 2, 3, 1)
    out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
    out_video = (out_video * 255).astype(np.uint8)

    temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
    save_video(out_video, temporal_video_path)

    return input_history, out_video[-1], temporal_video_path, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx


css = """
h1 {
    text-align: center;
    display: block;
}
"""


def on_select(evt: gr.SelectData):
    # Currently unwired helper: returns the example row at the selected index.
    selected_index = evt.index
    return examples[selected_index]


with gr.Blocks(css=css) as demo:
    gr.Markdown(
        """
        # WORLDMEM: Long-term Consistent World Generation with Memory
        """
    )
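    # Per-session state: the initial frame plus the model's rolling caches (self_frames,
    # self_actions, self_poses, self_memory_c2w, self_frame_idx) live in gr.State
    # components defined below and are threaded through generate/reset, so each browser
    # session keeps an independent rollout.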
    example_actions = {
        "turn left + turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
        "turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
        "turn right+go backward+look up+turn left+look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
        "turn right+go forward+turn left": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
        "turn right+look up+turn right+look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
        "put item+go backward+put item+go backward": "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE",
    }

    selected_image = gr.State(ICE_PLAINS_IMAGE)

    with gr.Row(variant="panel"):
        video_display = gr.Video(autoplay=True, loop=True)
        image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")

    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
            log_output = gr.Textbox(label="History Log", interactive=False)
            gr.Markdown("### Action sequence examples.")
            with gr.Row():
                buttons = []
                for action_key in list(example_actions.keys())[:2]:
                    with gr.Column(scale=len(action_key)):
                        buttons.append(gr.Button(action_key))
            with gr.Row():
                for action_key in list(example_actions.keys())[2:4]:
                    with gr.Column(scale=len(action_key)):
                        buttons.append(gr.Button(action_key))
            with gr.Row():
                for action_key in list(example_actions.keys())[4:6]:
                    with gr.Column(scale=len(action_key)):
                        buttons.append(gr.Button(action_key))
        with gr.Column(scale=1):
            slider_denoising_step = gr.Slider(minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1, label="Denoising Steps")
            slider_context_length = gr.Slider(minimum=2, maximum=10, value=worldmem.n_tokens, step=1, label="Context Length")
            slider_memory_length = gr.Slider(minimum=4, maximum=16, value=worldmem.condition_similar_length, step=1, label="Memory Length")
            submit_button = gr.Button("Generate")
            reset_btn = gr.Button("Reset")

    sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
    sampling_context_length_state = gr.State(worldmem.n_tokens)
    sampling_memory_length_state = gr.State(worldmem.condition_similar_length)
    memory_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
    self_frames = gr.State()
    self_actions = gr.State()
    self_poses = gr.State()
    self_memory_c2w = gr.State()
    self_frame_idx = gr.State()

    def set_action(action):
        return action

    # Each example button fills the action textbox with its pre-defined key sequence.
    for button, action_key in zip(buttons, list(example_actions.keys())):
        button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)

    gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.")
    with gr.Row():
        image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
        image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
        image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
        image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
        image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
        image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")

    gr.Markdown("### Click the examples below for a quick review, and continue generating based on them.")
    example_case = gr.Textbox(label="Case", visible=False)
    image_output = gr.Image(visible=False)
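    # `example_case` and `image_output` stay hidden: selecting a row in the gr.Examples
    # table below writes its case id into `example_case`, whose .change handler then loads
    # the matching pre-computed rollout via set_memory.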
    examples = gr.Examples(
        examples=example_images,
        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
        cache_examples=False,
    )

    example_case.change(
        fn=set_memory,
        inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_length],
        outputs=[log_output, image_display, video_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
    )

    gr.Markdown(
        """
        ## Instructions & Notes:

        1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
        2. You can continue generation by clicking **"Generate"** again and again; previous sequences are logged in the history panel.
        3. Click **"Reset"** to clear the current sequence and start fresh.
        4. Action sequences can be composed using the following keys:
            - W: turn up
            - S: turn down
            - A: turn left
            - D: turn right
            - Q: move forward
            - E: move backward
            - N: no-op (do nothing)
            - U: use item
        5. Higher denoising steps produce more detailed results but take longer; 20 steps is a good balance between quality and speed. The same applies to context and memory length.
        6. For faster performance, we recommend running the demo locally (~1 s/frame on an H100 vs. ~5 s/frame on Spaces).
        7. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
        8. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **zeqixiao1@gmail.com**.
        """
    )

    submit_button.click(
        generate,
        inputs=[input_box, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
        outputs=[image_display, video_display, log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
    )
    reset_btn.click(
        reset,
        inputs=[selected_image],
        outputs=[log_output, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx],
    )

    image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
    image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
    image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
    image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
    image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
    image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image, image_display, memory_frames, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx])
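    # The slider callbacks below update the shared `worldmem` object in place; the returned
    # gr.State values only mirror the chosen setting for this session.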
    slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
    slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
    slider_memory_length.change(fn=set_memory_length, inputs=[slider_memory_length, sampling_memory_length_state], outputs=sampling_memory_length_state)

demo.launch()
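# Local use (assumption): running this script directly serves the demo on Gradio's default
# port; pass share=True to demo.launch() above if a temporary public link is needed.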