# sdxlsae / app.py
# anonymous-author-129's picture
# Update app.py
# 73eb294 verified
import gradio as gr
import os
import torch
from PIL import Image
from SDLens import HookedStableDiffusionXLPipeline
from SAE import SparseAutoencoder
from utils import TimedHook, add_feature_on_area_base, replace_with_feature_base, add_feature_on_area_turbo, replace_with_feature_turbo
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import threading
import spaces
# Short block codes shown in the UI, mapped to the UNet submodule paths that are
# cached (Explore tab) or hooked (Paint! tab) in the pipeline. Insertion order
# is the order in which blocks are cached and processed.
code_to_block = dict(
    [
        ("down.2.1", "unet.down_blocks.2.attentions.1"),
        ("mid.0", "unet.mid_block.attentions.0"),
        ("up.0.1", "unet.up_blocks.0.attentions.1"),
        ("up.0.0", "unet.up_blocks.0.attentions.0"),
    ]
)
# Shared lock held by every callback while it runs the pipeline.
lock = threading.Lock()
# Guidance scales used when a callback receives no explicit guidance value:
# one for the SDXL-base checkpoint, one for every other (Turbo) checkpoint.
base_guidance_scale_default = 8.0
turbo_guidance_scale_default = 0.0
def process_cache(cache, saes_dict, timestep=None, blocks=None):
    """Encode cached block residuals with their SAEs and rank the features.

    For every block, the residual ``output - input`` recorded by the hooked
    pipeline is encoded with that block's sparse autoencoder.

    Args:
        cache: dict with ``"input"`` and ``"output"`` sub-dicts keyed by the
            fully qualified block path. Each entry is indexed as a 5-D tensor
            and permuted with ``(0, 1, 3, 4, 2)``; presumably
            ``(batch, timesteps, channels, height, width)`` — TODO confirm
            against SDLens.
        saes_dict: block code -> object exposing ``encode(tensor)``.
        timestep: optional index into the timestep dimension. When it is given
            and in range, only that step is encoded. Otherwise (``None``,
            negative or too large) all cached steps are encoded and their
            activation maps averaged.
        blocks: optional block-code -> block-path mapping; defaults to the
            module-level ``code_to_block``.

    Returns:
        ``(top_features_dict, sparse_maps_dict)`` keyed by block code.
        ``top_features_dict[code]`` lists the indices of the (up to) 10 features
        with the highest mean activation, sorted by decreasing activation.
        ``sparse_maps_dict[code]`` is the channels-last activation map as a
        NumPy array (one spatial map per feature).
    """
    if blocks is None:
        blocks = code_to_block
    top_features_dict = {}
    sparse_maps_dict = {}
    for code, block in blocks.items():
        sae = saes_dict[code]
        diff = cache["output"][block] - cache["input"][block]
        if diff.shape[0] == 2:  # guidance is on and we need to select the second output
            diff = diff[1].unsqueeze(0)
        # Select a single timestep only for a valid, non-negative index; a
        # negative index would otherwise produce an empty slice.
        if timestep is not None and 0 <= timestep < diff.shape[1]:
            diff = diff[:, timestep:timestep + 1]
        # Channels last for the SAE; drops the batch dim and a length-1 time dim.
        diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
        with torch.no_grad():
            sparse_maps = sae.encode(diff)
        # Several timesteps remain: average them so the map is always
        # (rows, cols, features). Without this, the spatial mean below would
        # collapse (time, rows) and leave a 2-D result.
        if sparse_maps.dim() == 4:
            sparse_maps = sparse_maps.mean(dim=0)
        averages = torch.mean(sparse_maps, dim=(0, 1))
        top_features = torch.topk(averages, min(10, averages.shape[-1])).indices
        top_features_dict[code] = top_features.cpu().tolist()
        sparse_maps_dict[code] = sparse_maps.cpu().numpy()
    return top_features_dict, sparse_maps_dict
def plot_image_heatmap(cache, block_select, radio):
    """Overlay one SAE feature's activation map on the generated image.

    Args:
        cache: Explore-tab state with ``"image"`` (a PIL image) and
            ``"heatmaps"`` (block code -> array indexed ``[row, col, feature]``).
        block_select: dropdown label; its first word is the block code.
        radio: selected feature index (any value ``int()`` accepts).

    Returns:
        An RGBA PIL image. It is the generated image with the feature map
        min-max normalised, upsampled to the image size, coloured with ``jet``
        and alpha-composited on top. The lowest colour bin is fully
        transparent and all others are drawn at 0.6 alpha.
    """
    code = block_select.split()[0]
    feature = int(radio)
    heatmap = np.asarray(cache["heatmaps"][code][:, :, feature], dtype=float)
    image = cache["image"].convert("RGBA")
    # Nearest-neighbour upsample each map cell to image pixels. The factor is
    # derived from the image size (e.g. a 1024 px image over a 32-cell map
    # gives 32) rather than hard-coded.
    img_w, img_h = image.size
    rows, cols = heatmap.shape
    heatmap = np.kron(heatmap, np.ones((max(1, img_h // rows), max(1, img_w // cols))))
    jet = plt.cm.jet
    colors = jet(np.arange(jet.N))
    colors[:1, -1] = 0  # lowest bin transparent so inactive regions show the image
    colors[1:, -1] = 0.6
    cmap = ListedColormap(colors)
    lo, hi = np.min(heatmap), np.max(heatmap)
    if hi > lo:
        heatmap = (heatmap - lo) / (hi - lo)
    else:
        # Constant map (e.g. the feature never fires): avoid 0/0 -> NaN and
        # render it in the transparent bin.
        heatmap = np.zeros_like(heatmap)
    heatmap_rgba = cmap(heatmap)
    heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
    # alpha_composite requires identical sizes; fix up non-divisible shapes.
    if heatmap_image.size != image.size:
        heatmap_image = heatmap_image.resize(image.size, Image.NEAREST)
    return Image.alpha_composite(image, heatmap_image)
def create_prompt_part(pipe, saes_dict, demo):
    """Build the "Explore" tab: generate an image and browse its top SAE features.

    Args:
        pipe: hooked SDXL pipeline. ``pipe.pipe.name_or_path`` is compared with
            the SDXL-base checkpoint id to choose defaults and to show or hide
            the steps / guidance / timestep controls.
        saes_dict: block code -> sparse autoencoder, forwarded to ``process_cache``.
        demo: enclosing ``gr.Blocks``; its ``load`` event runs an initial generation.

    Returns:
        The ``gr.Tab`` containing the Explore UI.
    """
    @spaces.GPU
    def image_gen(prompt, timestep=None, num_steps=None, guidance_scale=None):
        """Generate an image (seed 42) while caching the inputs/outputs of every block.

        Returns:
            The generated image, and the state dict read by ``update_radio`` /
            ``update_img``: ``"image"``, ``"heatmaps"`` (per-block SAE maps from
            ``process_cache``) and ``"features"`` (per-block top feature indices).
        """
        # Hold the shared module-level lock while the pipeline runs; the
        # try/finally releases it even if the pipeline raises.
        lock.acquire()
        try:
            # Default values
            is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
            default_n_steps = 25 if is_base_model else 1
            default_guidance = base_guidance_scale_default if is_base_model else turbo_guidance_scale_default
            # Use provided values if available, otherwise use defaults
            n_steps = default_n_steps if num_steps is None else int(num_steps)
            guidance = default_guidance if guidance_scale is None else float(guidance_scale)
            # Convert timestep to integer if it's not None
            timestep_int = None if timestep is None else int(timestep)
            images, cache = pipe.run_with_cache(
                prompt,
                positions_to_cache=list(code_to_block.values()),
                num_inference_steps=n_steps,
                generator=torch.Generator(device="cpu").manual_seed(42),
                guidance_scale=guidance,
                save_input=True,
                save_output=True
            )
        finally:
            lock.release()
        # SAE encoding of the cached activations runs after the lock is released.
        top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict, timestep_int)
        return images.images[0], {
            "image": images.images[0],
            "heatmaps": top_sparse_maps_dict,
            "features": top_features_dict
        }

    def update_radio(cache, block_select):
        """Offer the selected block's top feature indices as radio choices."""
        code = block_select.split()[0]
        return gr.update(choices=cache["features"][code])

    def update_img(cache, block_select, radio):
        """Redraw the image with the selected feature's heatmap overlaid."""
        new_img = plot_image_heatmap(cache, block_select, radio)
        return new_img

    def update_visibility():
        """Return visibility updates for two components (visible only for SDXL base).

        NOTE(review): not wired to any event in this function; appears unused.
        """
        is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
        return gr.update(visible=is_base_model), gr.update(visible=is_base_model)

    with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
        # Per-session state filled by image_gen: last image, SAE maps, top features.
        cache = gr.State(value={
            "image": None,
            "heatmaps": None,
            "features": []
        })
        with gr.Row():
            # Left column: prompt, generate button and the image / heatmap view.
            with gr.Column(scale=7):
                with gr.Row(equal_height=True):
                    # NOTE(review): "eathing" in the default prompt looks like a typo for
                    # "eating"; fixing it changes the image generated on page load.
                    prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eathing a dish with peas.")
                    button = gr.Button("Generate", elem_classes="generate_button1")
                with gr.Row():
                    image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
            # Right column: block choice, SDXL-base-only controls, feature picker.
            with gr.Column(scale=4):
                block_select = gr.Dropdown(
                    choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
                    value="down.2.1 (composition)",
                    label="Select block",
                    elem_id="block_select",
                    interactive=True
                )
                # Add SDXL base specific controls
                is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
                with gr.Group() as sdxl_base_controls:
                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=25 if is_base_model else 1,
                        step=1,
                        label="Number of steps",
                        elem_id="steps_slider",
                        interactive=True,
                        visible=is_base_model
                    )
                    guidance_slider = gr.Slider(
                        minimum=0.0,
                        maximum=15.0,
                        value=base_guidance_scale_default if is_base_model else turbo_guidance_scale_default,
                        step=0.1,
                        label="Guidance scale",
                        elem_id="guidance_slider",
                        interactive=True,
                        visible=is_base_model
                    )
                    # Add timestep selector
                    # NOTE(review): Gradio may initialise a Slider created with value=None
                    # to its minimum (0); if so, image_gen never receives timestep=None and
                    # the "average across all steps" path is unreachable from the UI — confirm.
                    n_steps = 25 if is_base_model else 1
                    timestep_selector = gr.Slider(
                        minimum=0,
                        maximum=n_steps-1,
                        value=None,
                        step=1,
                        label="Timestep (leave empty for average across all steps)",
                        elem_id="timestep_selector",
                        interactive=True,
                        visible=is_base_model
                    )
                    recompute_button = gr.Button("Recompute", elem_id="recompute_button",
                                                 visible=is_base_model)
                # Update max timestep when steps change
                steps_slider.change(lambda s: gr.update(maximum=s-1), [steps_slider], [timestep_selector])
                radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
        # Event wiring: generation refreshes the state; state / block changes refresh
        # the feature choices; picking a feature redraws the overlay.
        button.click(image_gen, [prompt_field, timestep_selector, steps_slider, guidance_slider], outputs=[image, cache])
        cache.change(update_radio, [cache, block_select], outputs=[radio])
        block_select.select(update_radio, [cache, block_select], outputs=[radio])
        radio.select(update_img, [cache, block_select, radio], outputs=[image])
        recompute_button.click(image_gen, [prompt_field, timestep_selector, steps_slider, guidance_slider], outputs=[image, cache])
        demo.load(image_gen, [prompt_field, timestep_selector, steps_slider, guidance_slider], outputs=[image, cache])
    return explore_tab
def downsample_mask(image, factor):
    """Mean-pool a 2-D array over non-overlapping ``factor`` x ``factor`` tiles.

    Both dimensions of ``image`` must be divisible by ``factor``; otherwise the
    reshape raises ``ValueError``.

    Args:
        image: 2-D NumPy array, e.g. a painted-pixel mask.
        factor: tile edge length in pixels.

    Returns:
        Array of shape ``(rows // factor, cols // factor)`` holding per-tile means.
    """
    n_rows = image.shape[0] // factor
    n_cols = image.shape[1] // factor
    tiles = image.reshape(n_rows, factor, n_cols, factor)
    return tiles.mean(axis=(1, 3))
def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
    """Build the "Paint!" tab: generate an image, paint a mask and steer one SAE feature.

    Args:
        pipe: hooked SDXL pipeline. ``pipe.pipe.name_or_path`` selects the
            SDXL-base code path (timed hooks, 16x mask pooling) versus the
            Turbo path (32x pooling).
        saes_dict: block code -> sparse autoencoder (must expose ``.k``).
        means_dict: block code -> per-feature values indexed by brush index;
            used to scale the injected feature.
        demo: enclosing ``gr.Blocks``; its ``load`` event generates the initial image.

    Returns:
        The ``gr.Tab`` holding the Paint! UI.
    """
    # The loaded checkpoint is fixed for the lifetime of the app, so resolve
    # its identity once instead of inside every callback.
    is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
    default_guidance = base_guidance_scale_default if is_base_model else turbo_guidance_scale_default

    def _resolve_guidance(guidance_scale):
        """Return ``guidance_scale`` as a float, or the model default when it is unset."""
        return default_guidance if guidance_scale is None else float(guidance_scale)

    @spaces.GPU
    def image_gen(prompt, num_steps, guidance_scale=None):
        """Generate an un-steered image (seed 42).

        Returns the image twice: once for the sketchpad state and once for the
        output panel. A 1024x1024 result is downscaled to 512x512 to match the
        sketchpad canvas.
        """
        n_steps = int(num_steps)
        guidance = _resolve_guidance(guidance_scale)
        # `with` guarantees the shared lock is released however the run ends.
        with lock:
            images = pipe.run_with_hooks(
                prompt,
                position_hook_dict={},
                num_inference_steps=n_steps,
                generator=torch.Generator(device="cpu").manual_seed(42),
                guidance_scale=guidance,
            )
        image = images.images[0]
        if image.size == (1024, 1024):
            image = image.resize((512, 512))
        return image, image

    @spaces.GPU
    def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image, guidance_scale=None, start_index=None, end_index=None):
        """Regenerate ``prompt`` while adding feature ``brush_index`` inside the painted area.

        Args:
            prompt: text prompt (same fixed seed as ``image_gen``).
            block_str: dropdown label; its first word is the block code.
            brush_index: SAE feature index to inject.
            strength: multiplier applied on top of the feature's mean value.
            num_steps: number of denoising steps.
            input_image: sketchpad value; non-zero alpha in its first layer is the mask.
            guidance_scale: CFG scale, or ``None`` for the model default.
            start_index, end_index: SDXL base only. The step range
                ``[start_index, end_index)`` passed to ``TimedHook``, clamped to
                ``[0, num_steps]`` and kept non-empty where possible.

        Returns:
            The steered image.
        """
        block = block_str.split(" ")[0]
        brush_index = int(brush_index)
        n_steps = int(num_steps)
        mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
        # Pool the canvas mask down to the hooked block's grid: factor 16 for
        # SDXL base, 32 for Turbo.
        mask = downsample_mask(mask, 16 if is_base_model else 32)
        mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
        if mask.sum() == 0:
            gr.Info("No mask selected, please draw on the input image")
        # Amount of the feature to add at each location; fixed for this request.
        feature_map = mask * means_dict[block][brush_index] * strength
        if is_base_model:
            # Default to the whole schedule, then clamp into [0, n_steps].
            start = 0 if start_index is None else int(start_index)
            end = n_steps if end_index is None else int(end_index)
            start = max(0, min(start, n_steps))
            end = max(0, min(end, n_steps))
            # Ensure start is less than end
            if start >= end:
                start = max(0, end - 1)

            def myhook(module, input, output):
                return add_feature_on_area_base(
                    saes_dict[block],
                    brush_index,
                    feature_map,
                    module,
                    input,
                    output)
            # TimedHook presumably applies myhook only on the listed step
            # indices — see utils.TimedHook.
            hook = TimedHook(myhook, n_steps, np.arange(start, end))
        else:
            def hook(module, input, output):
                return add_feature_on_area_turbo(
                    saes_dict[block],
                    brush_index,
                    feature_map,
                    module,
                    input,
                    output)
        guidance = _resolve_guidance(guidance_scale)
        with lock:
            image = pipe.run_with_hooks(
                prompt,
                position_hook_dict={code_to_block[block]: hook},
                num_inference_steps=n_steps,
                generator=torch.Generator(device="cpu").manual_seed(42),
                guidance_scale=guidance
            ).images[0]
        return image

    @spaces.GPU
    def feature_icon(block_str, brush_index, guidance_scale=None):
        """Render an empty-prompt image whose block output is replaced by one feature.

        Uses 25 steps for SDXL base and 1 for Turbo, regardless of the UI step count.
        """
        block = block_str.split(" ")[0]
        brush_index = int(brush_index)
        if block in ["mid.0", "up.0.0"]:
            gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
        replace_fn = replace_with_feature_base if is_base_model else replace_with_feature_turbo
        sae = saes_dict[block]
        feature_value = means_dict[block][brush_index] * sae.k

        def hook(module, input, output):
            return replace_fn(
                sae,
                brush_index,
                feature_value,
                module,
                input,
                output)

        n_steps = 25 if is_base_model else 1
        guidance = _resolve_guidance(guidance_scale)
        with lock:
            image = pipe.run_with_hooks(
                "",
                position_hook_dict={code_to_block[block]: hook},
                num_inference_steps=n_steps,
                generator=torch.Generator(device="cpu").manual_seed(42),
                guidance_scale=guidance,
            ).images[0]
        return image

    with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
        image_state = gr.State(value=None)
        default_n_steps = 25 if is_base_model else 1
        with gr.Row():
            with gr.Column(scale=3):
                # Generation column
                with gr.Row():
                    # prompt and num_steps
                    prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, closeup", elem_id="prompt_input")
                with gr.Row():
                    num_steps = gr.Number(value=default_n_steps, label="Number of steps", minimum=1, maximum=50, elem_id="num_steps", precision=0)
                    guidance_slider = gr.Slider(
                        minimum=0.0,
                        maximum=15.0,
                        value=base_guidance_scale_default if is_base_model else turbo_guidance_scale_default,
                        step=0.1,
                        label="Guidance scale",
                        elem_id="paint_guidance_slider",
                        interactive=True,
                        visible=is_base_model
                    )
                with gr.Row():
                    # Generate button
                    button_generate = gr.Button("Generate", elem_id="generate_button")
            with gr.Column(scale=3):
                # Intervention column
                with gr.Row():
                    # dropdowns and number inputs
                    with gr.Column(scale=7):
                        with gr.Row():
                            block_select = gr.Dropdown(
                                choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
                                value="down.2.1 (composition)",
                                label="Select block",
                                elem_id="block_select"
                            )
                            brush_index = gr.Number(value=4998, label="Brush index", minimum=0, maximum=5119, elem_id="brush_index", precision=0)
                        with gr.Row():
                            button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
                        with gr.Row():
                            gr.Markdown("**TimedHook Range** (which steps to apply the feature)", visible=is_base_model)
                        with gr.Row():
                            start_index = gr.Number(value=5 if is_base_model else 0, label="Start index", minimum=0, maximum=default_n_steps, elem_id="start_index", precision=0, visible=is_base_model)
                            end_index = gr.Number(value=20 if is_base_model else 1, label="End index", minimum=0, maximum=default_n_steps, elem_id="end_index", precision=0, visible=is_base_model)
                    with gr.Column(scale=3):
                        with gr.Row():
                            strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
                        with gr.Row():
                            button = gr.Button('Apply', elem_id="apply_button")
        with gr.Row():
            with gr.Column():
                # Input image
                i_image = gr.Sketchpad(
                    height=600,
                    layers=False, transforms=None, placeholder="Generate and paint!",
                    container=False,
                    brush=gr.Brush(default_size=40, color_mode="fixed", colors=['black']),
                    canvas_size=(512, 512),
                    label="Input Image")
                clear_button = gr.Button("Clear")
                # "Clear" restores the last generated (unpainted) image.
                clear_button.click(lambda x: x, [image_state], [i_image])
            # Output image
            o_image = gr.Image(width=512, height=512, label="Output Image")
        # Set up the click events
        button_generate.click(image_gen, inputs=[prompt_field, num_steps, guidance_slider], outputs=[image_state, o_image])
        image_state.change(lambda x: x, [image_state], [i_image])
        if is_base_model:
            # Update max values for start_index and end_index when num_steps changes
            def update_index_maxes(steps):
                return gr.update(maximum=steps), gr.update(maximum=steps)
            num_steps.change(update_index_maxes, [num_steps], [start_index, end_index])
        button.click(image_mod,
                     inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image, guidance_slider, start_index, end_index],
                     outputs=o_image)
        button_icon.click(feature_icon, inputs=[block_select, brush_index, guidance_slider], outputs=o_image)
        demo.load(image_gen, [prompt_field, num_steps, guidance_slider], outputs=[image_state, o_image])
    return intervene_tab
def create_top_images_part(demo):
    """Build the "Top Images" tab showing precomputed top-activating images.

    Images are loaded by URL from the ``anonymous-author-129/sdxl_sae_images``
    Hugging Face dataset, one JPEG per ``{block}/{feature}``.

    Args:
        demo: enclosing ``gr.Blocks``; its ``load`` event shows the initial image.

    Returns:
        The ``gr.Tab`` holding the UI.
    """
    def update_top_images(block_select, brush_index):
        """Return the image URL for the selected block/feature, or ``None`` if no index is set."""
        if brush_index is None:
            # Empty number field: clear the image instead of requesting ".../None.jpg".
            return None
        block = block_select.split(" ")[0]
        # int() keeps the file name integral even if the field yields a float.
        url = f"https://huggingface.co/datasets/anonymous-author-129/sdxl_sae_images/resolve/main/{block}/{int(brush_index)}.jpg"
        return url

    with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
        with gr.Row():
            block_select = gr.Dropdown(
                choices=["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
                value="down.2.1 (composition)",
                label="Select block"
            )
            brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=5119, precision=0)
        with gr.Row():
            image = gr.Image(width=600, height=600, label="Top Images")
        block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
        brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
        demo.load(update_top_images, [block_select, brush_index], outputs=[image])
    return top_images_tab
def create_intro_part():
    """Build the static "Instructions" tab (Markdown only).

    Returns:
        The ``gr.Tab`` containing the instructions.
    """
    # The Markdown body stays at column 0: leading indentation would turn its
    # lines into Markdown code blocks.
    with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
        gr.Markdown(
            '''# One-Step is Enough: Sparse Autoencoders for Text-to-Image Diffusion Models
## Stable Diffusion XL multistep version
## Note
If you encounter GPU time limit errors, don't worry, the app still works and you can use it freely.
## Demo Overview
This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL (Turbo) model.
## How to Use
### Explore
* Enter a prompt in the text box and click on the "Generate" button to generate an image.
* You can observe the active features in different blocks plotted on top of the generated image.
### Top Images
* For each feature, you can view the top images that activate the feature the most.
### Paint!
* Generate an image using the prompt.
* Paint on the generated image to apply interventions.
* Use the "Feature Icon" button to understand how the selected brush functions.
### Remarks
* Not all brushes mix well with all images. Experiment with different brushes and strengths.
* Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
* This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
### Interesting features to try
To get started, try the following features:
- down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon)
- up.0.1 (style): 4977 (tiger stripes) 90 (fur) 2615 (twilight blur)
'''
        )
    return intro_tab
def create_demo(pipe, saes_dict, means_dict):
    """Assemble the Gradio app with its four tabs.

    The tabs are Instructions, Explore, Top Images and Paint!, in that order.

    Args:
        pipe: hooked SDXL pipeline shared by the Explore and Paint! tabs.
        saes_dict: block code -> sparse autoencoder.
        means_dict: block code -> per-feature values used by the Paint! tab.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    custom_css = """
.tabs button {
font-size: 20px !important; /* Adjust font size for tab text */
padding: 10px !important; /* Adjust padding to make the tabs bigger */
font-weight: bold !important; /* Adjust font weight to make the text bold */
}
.generate_button1 {
max-width: 160px !important;
margin-top: 20px !important;
margin-bottom: 20px !important;
}
"""
    with gr.Blocks(css=custom_css) as demo:
        # The UI tab order follows the order of these builders.
        tab_builders = (
            create_intro_part,
            lambda: create_prompt_part(pipe, saes_dict, demo),
            lambda: create_top_images_part(demo),
            lambda: create_intervene_part(pipe, saes_dict, means_dict, demo),
        )
        for build_tab in tab_builders:
            # Each builder creates its tab inside `demo`; the returned tab is
            # re-entered with an empty body, adding no components.
            with build_tab():
                pass
    return demo
if __name__ == "__main__":
    # os, gr, torch, HookedStableDiffusionXLPipeline, SparseAutoencoder and
    # code_to_block are all defined at module level; re-importing them or
    # redefining the block mapping here risks the two copies diverging.
    dtype = torch.float32
    pipe = HookedStableDiffusionXLPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0',
        torch_dtype=dtype,
        # Request the fp16 weight variant only when running in half precision.
        variant=("fp16" if dtype == torch.float16 else None)
    )
    pipe.set_progress_bar_config(disable=True)
    pipe.to('cuda')
    path_to_checkpoints = './checkpoints/'
    saes_dict = {}
    means_dict = {}
    for code, block in code_to_block.items():
        # One checkpoint directory per block holds both the SAE and mean.pt.
        checkpoint_dir = os.path.join(
            path_to_checkpoints, f"{block}_k10_hidden5120_auxk256_bs4096_lr0.0001", "final"
        )
        sae = SparseAutoencoder.load_from_disk(checkpoint_dir)
        means = torch.load(os.path.join(checkpoint_dir, "mean.pt"), weights_only=True)
        saes_dict[code] = sae.to('cuda', dtype=dtype)
        means_dict[code] = means.to('cuda', dtype=dtype)
    demo = create_demo(pipe, saes_dict, means_dict)
    demo.launch()