'''
conda activate zero123
cd zero123
python gradio_new.py 0
'''
import diffusers  # 0.12.1
import math
import fire
import gradio as gr
import lovely_numpy
import lovely_tensors
import numpy as np
import os
import plotly.express as px
import plotly.graph_objects as go
import rich
import sys
import time
import torch
from contextlib import nullcontext
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from einops import rearrange
from functools import partial
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import create_carvekit_interface, load_and_preprocess, instantiate_from_config
from lovely_numpy import lo
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from transformers import AutoFeatureExtractor
from torch import autocast
from torchvision import transforms
_SHOW_DESC = True
_SHOW_INTERMEDIATE = False
# _SHOW_INTERMEDIATE = True
_GPU_INDEX = 0
# _GPU_INDEX = 2

# _TITLE = 'Zero-Shot Control of Camera Viewpoints within a Single Image'
_TITLE = 'Zero-1-to-3: Zero-shot One Image to 3D Object'

# This demo allows you to generate novel viewpoints of an object depicted in an input image using a fine-tuned version of Stable Diffusion.
_DESCRIPTION = '''
This live demo allows you to control camera rotation and thereby generate novel viewpoints of an object within a single image.
It is based on Stable Diffusion. Check out our [project webpage](https://zero123.cs.columbia.edu/) and [paper](https://arxiv.org/pdf/2303.11328.pdf) if you want to learn more about the method!
Note that this model is not intended for images of humans or faces, and is unlikely to work well for them.
'''

_ARTICLE = 'See uses.md'
def load_model_from_config(config, ckpt, device, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location='cpu')
    if 'global_step' in pl_sd:
        print(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)

    model.to(device)
    model.eval()
    return model
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale,
                 ddim_eta, x, y, z):
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope('cuda'):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
            T = torch.tensor([math.radians(x), math.sin(
                math.radians(y)), math.cos(math.radians(y)), z])
            T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)
            cond = {}
            cond['c_crossattn'] = [c]
            # Encode the input view once and reuse it for the concat conditioning.
            c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
            cond['c_concat'] = [c_concat.repeat(n_samples, 1, 1, 1)]
            if scale != 1.0:
                uc = {}
                uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)]
                uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)
            print(samples_ddim.shape)
            # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
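
# Illustrative note on the pose conditioning built above (values follow directly from
# the code): the relative camera offset (x, y, z) = (polar delta in degrees, azimuth
# delta in degrees, radius delta) is packed as T = [radians(x), sin(radians(y)),
# cos(radians(y)), z] before being concatenated with the CLIP image embedding and
# projected by cc_projection. For example, x=30, y=90, z=0 gives
# T = [0.5236, 1.0, 0.0, 0.0]; encoding azimuth as (sin, cos) keeps -180 and +180
# degrees mapped to the same conditioning.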
class CameraVisualizer:
    def __init__(self, gradio_plot):
        self._gradio_plot = gradio_plot
        self._fig = None
        self._polar = 0.0
        self._azimuth = 0.0
        self._radius = 0.0
        self._raw_image = None
        self._8bit_image = None
        self._image_colorscale = None

    def polar_change(self, value):
        self._polar = value
        # return self.update_figure()

    def azimuth_change(self, value):
        self._azimuth = value
        # return self.update_figure()

    def radius_change(self, value):
        self._radius = value
        # return self.update_figure()

    def encode_image(self, raw_image):
        '''
        :param raw_image (H, W, 3) array of uint8 in [0, 255].
        '''
        # https://stackoverflow.com/questions/60685749/python-plotly-how-to-add-an-image-to-a-3d-scatter-plot
        dum_img = Image.fromarray(np.ones((3, 3, 3), dtype='uint8')).convert('P', palette='WEB')
        idx_to_color = np.array(dum_img.getpalette()).reshape((-1, 3))

        self._raw_image = raw_image
        self._8bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None)
        # self._8bit_image = Image.fromarray(raw_image.clip(0, 254)).convert(
        #     'P', palette='WEB', dither=None)
        self._image_colorscale = [
            [i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)]

        # return self.update_figure()

    def update_figure(self):
        fig = go.Figure()

        if self._raw_image is not None:
            (H, W, C) = self._raw_image.shape

            x = np.zeros((H, W))
            (y, z) = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(1.0, -1.0, H) * H / W)
            print('x:', lo(x))
            print('y:', lo(y))
            print('z:', lo(z))

            fig.add_trace(go.Surface(
                x=x, y=y, z=z,
                surfacecolor=self._8bit_image,
                cmin=0,
                cmax=255,
                colorscale=self._image_colorscale,
                showscale=False,
                lighting_diffuse=1.0,
                lighting_ambient=1.0,
                lighting_fresnel=1.0,
                lighting_roughness=1.0,
                lighting_specular=0.3))

            scene_bounds = 3.5
            base_radius = 2.5
            zoom_scale = 1.5  # Note that input radius offset is in [-0.5, 0.5].
            fov_deg = 50.0
            edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1)]

            input_cone = calc_cam_cone_pts_3d(
                0.0, 0.0, base_radius, fov_deg)  # (5, 3).
            output_cone = calc_cam_cone_pts_3d(
                self._polar, self._azimuth, base_radius + self._radius * zoom_scale, fov_deg)  # (5, 3).
            # print('input_cone:', lo(input_cone).v)
            # print('output_cone:', lo(output_cone).v)

            for (cone, clr, legend) in [(input_cone, 'green', 'Input view'),
                                        (output_cone, 'blue', 'Target view')]:

                for (i, edge) in enumerate(edges):
                    (x1, x2) = (cone[edge[0], 0], cone[edge[1], 0])
                    (y1, y2) = (cone[edge[0], 1], cone[edge[1], 1])
                    (z1, z2) = (cone[edge[0], 2], cone[edge[1], 2])
                    fig.add_trace(go.Scatter3d(
                        x=[x1, x2], y=[y1, y2], z=[z1, z2], mode='lines',
                        line=dict(color=clr, width=3),
                        name=legend, showlegend=(i == 0)))
                    # text=(legend if i == 0 else None),
                    # textposition='bottom center'))
                    # hoverinfo='text',
                    # hovertext='hovertext'))

                # Add label.
                if cone[0, 2] <= base_radius / 2.0:
                    fig.add_trace(go.Scatter3d(
                        x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] - 0.05], showlegend=False,
                        mode='text', text=legend, textposition='bottom center'))
                else:
                    fig.add_trace(go.Scatter3d(
                        x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] + 0.05], showlegend=False,
                        mode='text', text=legend, textposition='top center'))

            # look at center of scene
            fig.update_layout(
                # width=640,
                # height=480,
                # height=400,
                height=360,
                autosize=True,
                hovermode=False,
                margin=go.layout.Margin(l=0, r=0, b=0, t=0),
                showlegend=True,
                legend=dict(
                    yanchor='bottom',
                    y=0.01,
                    xanchor='right',
                    x=0.99,
                ),
                scene=dict(
                    aspectmode='manual',
                    aspectratio=dict(x=1, y=1, z=1.0),
                    camera=dict(
                        eye=dict(x=base_radius - 1.6, y=0.0, z=0.6),
                        center=dict(x=0.0, y=0.0, z=0.0),
                        up=dict(x=0.0, y=0.0, z=1.0)),
                    xaxis_title='',
                    yaxis_title='',
                    zaxis_title='',
                    xaxis=dict(
                        range=[-scene_bounds, scene_bounds],
                        showticklabels=False,
                        showgrid=True,
                        zeroline=False,
                        showbackground=True,
                        showspikes=False,
                        showline=False,
                        ticks=''),
                    yaxis=dict(
                        range=[-scene_bounds, scene_bounds],
                        showticklabels=False,
                        showgrid=True,
                        zeroline=False,
                        showbackground=True,
                        showspikes=False,
                        showline=False,
                        ticks=''),
                    zaxis=dict(
                        range=[-scene_bounds, scene_bounds],
                        showticklabels=False,
                        showgrid=True,
                        zeroline=False,
                        showbackground=True,
                        showspikes=False,
                        showline=False,
                        ticks='')))

        self._fig = fig
        return fig
def preprocess_image(models, input_im, preprocess):
    '''
    :param input_im (PIL Image).
    :return input_im (H, W, 3) array in [0, 1].
    '''
    print('old input_im:', input_im.size)
    start_time = time.time()

    if preprocess:
        input_im = load_and_preprocess(models['carvekit'], input_im)
        input_im = (input_im / 255.0).astype(np.float32)
        # (H, W, 3) array in [0, 1].
    else:
        input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
        input_im = np.asarray(input_im, dtype=np.float32) / 255.0
        # (H, W, 4) array in [0, 1].

        # old method: thresholding background, very important
        # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1., 1.]

        # new method: apply correct method of compositing to avoid sudden transitions / thresholding
        # (smoothly transition foreground to white background based on alpha values)
        alpha = input_im[:, :, 3:4]
        white_im = np.ones_like(input_im)
        input_im = alpha * input_im + (1.0 - alpha) * white_im

        input_im = input_im[:, :, 0:3]
        # (H, W, 3) array in [0, 1].

    print(f'Infer foreground mask (preprocess_image) took {time.time() - start_time:.3f}s.')
    print('new input_im:', lo(input_im))

    return input_im
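
# Illustrative note on the compositing above: each RGBA pixel is blended onto a white
# background as out = alpha * rgb + (1 - alpha) * 1.0, so a pixel with rgb = 0.2 and
# alpha = 0.25 becomes 0.25 * 0.2 + 0.75 * 1.0 = 0.8. This avoids the hard edges
# produced by the old alpha-thresholding approach.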
def main_run(models, device, cam_vis, return_what,
             x=0.0, y=0.0, z=0.0,
             raw_im=None, preprocess=True,
             scale=3.0, n_samples=4, ddim_steps=50, ddim_eta=1.0,
             precision='fp32', h=256, w=256):
    '''
    :param raw_im (PIL Image).
    '''
    safety_checker_input = models['clip_fe'](raw_im, return_tensors='pt').to(device)
    (image, has_nsfw_concept) = models['nsfw'](
        images=np.ones((1, 3)), clip_input=safety_checker_input.pixel_values)
    print('has_nsfw_concept:', has_nsfw_concept)
    if np.any(has_nsfw_concept):
        print('NSFW content detected.')
        to_return = [None] * 10
        description = ('### <span style="color:red"> Unfortunately, '
                       'potential NSFW content was detected, '
                       'which is not supported by our model. '
                       'Please try again with a different image. </span>')
        if 'angles' in return_what:
            to_return[0] = 0.0
            to_return[1] = 0.0
            to_return[2] = 0.0
            to_return[3] = description
        else:
            to_return[0] = description
        return to_return
    else:
        print('Safety check passed.')

    input_im = preprocess_image(models, raw_im, preprocess)

    # if np.random.rand() < 0.3:
    #     description = ('Unfortunately, a human, a face, or potential NSFW content was detected, '
    #                    'which is not supported by our model.')
    #     if vis_only:
    #         return (None, None, description)
    #     else:
    #         return (None, None, None, description)

    show_in_im1 = (input_im * 255.0).astype(np.uint8)
    show_in_im2 = Image.fromarray(show_in_im1)

    if 'rand' in return_what:
        x = int(np.round(np.arcsin(np.random.uniform(-1.0, 1.0)) * 160.0 / np.pi))  # [-80, 80].
        y = int(np.round(np.random.uniform(-150.0, 150.0)))
        z = 0.0

    cam_vis.polar_change(x)
    cam_vis.azimuth_change(y)
    cam_vis.radius_change(z)
    cam_vis.encode_image(show_in_im1)
    new_fig = cam_vis.update_figure()

    if 'vis' in return_what:
        description = ('The viewpoints are visualized on the top right. '
                       'Click Run Generation to update the results on the bottom right.')
        if 'angles' in return_what:
            return (x, y, z, description, new_fig, show_in_im2)
        else:
            return (description, new_fig, show_in_im2)

    elif 'gen' in return_what:
        input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
        input_im = input_im * 2 - 1
        input_im = transforms.functional.resize(input_im, [h, w])

        sampler = DDIMSampler(models['turncam'])
        # used_x = -x  # NOTE: Polar makes more sense in Basile's opinion this way!
        used_x = x  # NOTE: Set this way for consistency.
        x_samples_ddim = sample_model(input_im, models['turncam'], sampler, precision, h, w,
                                      ddim_steps, n_samples, scale, ddim_eta, used_x, y, z)

        output_ims = []
        for x_sample in x_samples_ddim:
            x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))

        description = None

        if 'angles' in return_what:
            return (x, y, z, description, new_fig, show_in_im2, output_ims)
        else:
            return (description, new_fig, show_in_im2, output_ims)
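
# The return_what string controls both the work done and the output arity:
#   'vis'             -> (description, figure, preprocessed image)
#   'gen'             -> (description, figure, preprocessed image, generated images)
#   'angles_gen'      -> additionally prepends (x, y, z); used by the preset buttons
#   'rand_angles_gen' -> same, but with randomly sampled polar/azimuth angles
# Each tuple must line up with the corresponding Gradio `outputs` list wired up in
# run_demo() below.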
def calc_cam_cone_pts_3d(polar_deg, azimuth_deg, radius_m, fov_deg):
    '''
    :param polar_deg (float).
    :param azimuth_deg (float).
    :param radius_m (float).
    :param fov_deg (float).
    :return (5, 3) array of float with (x, y, z).
    '''
    polar_rad = np.deg2rad(polar_deg)
    azimuth_rad = np.deg2rad(azimuth_deg)
    fov_rad = np.deg2rad(fov_deg)
    polar_rad = -polar_rad  # NOTE: Inverse of how used_x relates to x.

    # Camera pose center:
    cam_x = radius_m * np.cos(azimuth_rad) * np.cos(polar_rad)
    cam_y = radius_m * np.sin(azimuth_rad) * np.cos(polar_rad)
    cam_z = radius_m * np.sin(polar_rad)

    # Obtain four corners of camera frustum, assuming it is looking at origin.
    # First, obtain camera extrinsics (rotation matrix only):
    camera_R = np.array([[np.cos(azimuth_rad) * np.cos(polar_rad),
                          -np.sin(azimuth_rad),
                          -np.cos(azimuth_rad) * np.sin(polar_rad)],
                         [np.sin(azimuth_rad) * np.cos(polar_rad),
                          np.cos(azimuth_rad),
                          -np.sin(azimuth_rad) * np.sin(polar_rad)],
                         [np.sin(polar_rad),
                          0.0,
                          np.cos(polar_rad)]])
    # print('camera_R:', lo(camera_R).v)

    # Multiply the corners in camera space by the rotation to obtain corners in world space:
    corn1 = [-1.0, np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
    corn2 = [-1.0, -np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
    corn3 = [-1.0, -np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
    corn4 = [-1.0, np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
    corn1 = np.dot(camera_R, corn1)
    corn2 = np.dot(camera_R, corn2)
    corn3 = np.dot(camera_R, corn3)
    corn4 = np.dot(camera_R, corn4)

    # Now attach as offset to actual 3D camera position:
    corn1 = np.array(corn1) / np.linalg.norm(corn1, ord=2)
    corn_x1 = cam_x + corn1[0]
    corn_y1 = cam_y + corn1[1]
    corn_z1 = cam_z + corn1[2]
    corn2 = np.array(corn2) / np.linalg.norm(corn2, ord=2)
    corn_x2 = cam_x + corn2[0]
    corn_y2 = cam_y + corn2[1]
    corn_z2 = cam_z + corn2[2]
    corn3 = np.array(corn3) / np.linalg.norm(corn3, ord=2)
    corn_x3 = cam_x + corn3[0]
    corn_y3 = cam_y + corn3[1]
    corn_z3 = cam_z + corn3[2]
    corn4 = np.array(corn4) / np.linalg.norm(corn4, ord=2)
    corn_x4 = cam_x + corn4[0]
    corn_y4 = cam_y + corn4[1]
    corn_z4 = cam_z + corn4[2]

    xs = [cam_x, corn_x1, corn_x2, corn_x3, corn_x4]
    ys = [cam_y, corn_y1, corn_y2, corn_y3, corn_y4]
    zs = [cam_z, corn_z1, corn_z2, corn_z3, corn_z4]

    return np.array([xs, ys, zs]).T
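
# Geometry note: the returned (5, 3) array holds the camera position (row 0) followed
# by the four frustum corners (rows 1-4), each at unit distance from the camera, so the
# "cone" edges drawn in CameraVisualizer.update_figure() connect the apex to the corners
# and the corners to each other.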
def run_demo(
        device_idx=_GPU_INDEX,
        ckpt='105000.ckpt',
        config='configs/sd-objaverse-finetune-c_concat-256.yaml'):

    print('sys.argv:', sys.argv)
    if len(sys.argv) > 1:
        print('old device_idx:', device_idx)
        device_idx = int(sys.argv[1])
        print('new device_idx:', device_idx)

    device = f'cuda:{device_idx}'
    config = OmegaConf.load(config)

    # Instantiate all models beforehand for efficiency.
    models = dict()
    print('Instantiating LatentDiffusion...')
    models['turncam'] = load_model_from_config(config, ckpt, device=device)
    print('Instantiating Carvekit HiInterface...')
    models['carvekit'] = create_carvekit_interface()
    print('Instantiating StableDiffusionSafetyChecker...')
    models['nsfw'] = StableDiffusionSafetyChecker.from_pretrained(
        'CompVis/stable-diffusion-safety-checker').to(device)
    print('Instantiating AutoFeatureExtractor...')
    models['clip_fe'] = AutoFeatureExtractor.from_pretrained(
        'CompVis/stable-diffusion-safety-checker')

    # Reduce NSFW false positives.
    # NOTE: At the time of writing, and for diffusers 0.12.1, the default parameters are:
    # models['nsfw'].concept_embeds_weights:
    # [0.1800, 0.1900, 0.2060, 0.2100, 0.1950, 0.1900, 0.1940, 0.1900, 0.1900, 0.2200, 0.1900,
    #  0.1900, 0.1950, 0.1984, 0.2100, 0.2140, 0.2000].
    # models['nsfw'].special_care_embeds_weights:
    # [0.1950, 0.2000, 0.2200].
    # We multiply all by some factor > 1 to make them less likely to be triggered.
    models['nsfw'].concept_embeds_weights *= 1.07
    models['nsfw'].special_care_embeds_weights *= 1.07

    with open('instructions.md', 'r') as f:
        article = f.read()

    # NOTE: Examples must match inputs
    # [polar_slider, azimuth_slider, radius_slider, image_block,
    #  preprocess_chk, scale_slider, samples_slider, steps_slider].
    example_fns = ['1_blue_arm.png', '2_cybercar.png', '3_sushi.png', '4_blackarm.png',
                   '5_cybercar.png', '6_burger.png', '7_london.png', '8_motor.png']
    num_examples = len(example_fns)
    example_fps = [os.path.join(os.path.dirname(__file__), 'assets', x) for x in example_fns]
    example_angles = [(-40.0, -65.0, 0.0), (-30.0, 90.0, 0.0), (45.0, -15.0, 0.0), (-75.0, 100.0, 0.0),
                      (-40.0, -75.0, 0.0), (-45.0, 0.0, 0.0), (-55.0, 90.0, 0.0), (-20.0, 125.0, 0.0)]
    examples_full = [[*example_angles[i], example_fps[i], True, 3, 4, 50] for i in range(num_examples)]
    print('examples_full:', examples_full)

    # Compose demo layout & data flow.
    demo = gr.Blocks(title=_TITLE)

    with demo:
        gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=0.9, variant='panel'):

                image_block = gr.Image(type='pil', image_mode='RGBA',
                                       label='Input image of single object')
                preprocess_chk = gr.Checkbox(
                    True, label='Preprocess image automatically (remove background and recenter object)')
                # info='If enabled, the uploaded image will be preprocessed to remove the background and recenter the object by cropping and/or padding as necessary. '
                # 'If disabled, the image will be used as-is, *BUT* a fully transparent or white background is required.'),

                gr.Markdown('*Try camera position presets:*')
                with gr.Row():
                    left_btn = gr.Button('View from the Left', variant='primary')
                    above_btn = gr.Button('View from Above', variant='primary')
                    right_btn = gr.Button('View from the Right', variant='primary')
                with gr.Row():
                    random_btn = gr.Button('Random Rotation', variant='primary')
                    below_btn = gr.Button('View from Below', variant='primary')
                    behind_btn = gr.Button('View from Behind', variant='primary')

                gr.Markdown('*Control camera position manually:*')
                polar_slider = gr.Slider(
                    -90, 90, value=0, step=5, label='Polar angle (vertical rotation in degrees)')
                # info='Positive values move the camera down, while negative values move the camera up.')
                azimuth_slider = gr.Slider(
                    -180, 180, value=0, step=5, label='Azimuth angle (horizontal rotation in degrees)')
                # info='Positive values move the camera right, while negative values move the camera left.')
                radius_slider = gr.Slider(
                    -0.5, 0.5, value=0.0, step=0.1, label='Zoom (relative distance from center)')
                # info='Positive values move the camera further away, while negative values move the camera closer.')

                samples_slider = gr.Slider(1, 8, value=4, step=1,
                                           label='Number of samples to generate')

                with gr.Accordion('Advanced options', open=False):
                    scale_slider = gr.Slider(0, 30, value=3, step=1,
                                             label='Diffusion guidance scale')
                    steps_slider = gr.Slider(5, 200, value=75, step=5,
                                             label='Number of diffusion inference steps')

                with gr.Row():
                    vis_btn = gr.Button('Visualize Angles', variant='secondary')
                    run_btn = gr.Button('Run Generation', variant='primary')

                desc_output = gr.Markdown(
                    'The results will appear on the right.', visible=_SHOW_DESC)

            with gr.Column(scale=1.1, variant='panel'):

                vis_output = gr.Plot(
                    label='Relationship between input (green) and output (blue) camera poses')

                gen_output = gr.Gallery(label='Generated images from specified new viewpoint')
                gen_output.style(grid=2)

                preproc_output = gr.Image(type='pil', image_mode='RGB',
                                          label='Preprocessed input image', visible=_SHOW_INTERMEDIATE)

        cam_vis = CameraVisualizer(vis_output)

        gr.Examples(
            examples=examples_full,  # NOTE: elements must match inputs list!
            fn=partial(main_run, models, device, cam_vis, 'gen'),
            inputs=[polar_slider, azimuth_slider, radius_slider,
                    image_block, preprocess_chk,
                    scale_slider, samples_slider, steps_slider],
            outputs=[desc_output, vis_output, preproc_output, gen_output],
            cache_examples=True,
            run_on_click=True,
        )

        gr.Markdown(article)

        # NOTE: I am forced to update vis_output for these preset buttons,
        # because otherwise the gradio plot always resets the plotly 3D viewpoint for some reason,
        # which might confuse the user into thinking that the plot has been updated too.

        vis_btn.click(fn=partial(main_run, models, device, cam_vis, 'vis'),
                      inputs=[polar_slider, azimuth_slider, radius_slider,
                              image_block, preprocess_chk],
                      outputs=[desc_output, vis_output, preproc_output])

        run_btn.click(fn=partial(main_run, models, device, cam_vis, 'gen'),
                      inputs=[polar_slider, azimuth_slider, radius_slider,
                              image_block, preprocess_chk,
                              scale_slider, samples_slider, steps_slider],
                      outputs=[desc_output, vis_output, preproc_output, gen_output])

        # NEW:
        preset_inputs = [image_block, preprocess_chk,
                         scale_slider, samples_slider, steps_slider]
        preset_outputs = [polar_slider, azimuth_slider, radius_slider,
                          desc_output, vis_output, preproc_output, gen_output]
        left_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
                                  0.0, -90.0, 0.0),
                       inputs=preset_inputs, outputs=preset_outputs)
        above_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
                                   -90.0, 0.0, 0.0),
                        inputs=preset_inputs, outputs=preset_outputs)
        right_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
                                   0.0, 90.0, 0.0),
                        inputs=preset_inputs, outputs=preset_outputs)
        random_btn.click(fn=partial(main_run, models, device, cam_vis, 'rand_angles_gen',
                                    -1.0, -1.0, -1.0),
                         inputs=preset_inputs, outputs=preset_outputs)
        below_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
                                   90.0, 0.0, 0.0),
                        inputs=preset_inputs, outputs=preset_outputs)
        behind_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
                                    0.0, 180.0, 0.0),
                         inputs=preset_inputs, outputs=preset_outputs)

    demo.launch(enable_queue=True)


if __name__ == '__main__':
    fire.Fire(run_demo)