Alchemist Collection
Dataset and checkpoints for the paper "Alchemist: Turning Public Text-to-Image Data into Generative Gold"
BAGEL-7B-MoT Alchemist is a T2I-finetuned version of BAGEL-7B-MoT, trained on the Alchemist dataset proposed in the research paper "Alchemist: Turning Public Text-to-Image Data into Generative Gold". The model generates images with improved aesthetics and complexity. More details on the dataset and training setup can be found in the paper.
For installation and usage instructions, follow BAGEL's official GitHub repository:
1️⃣ Set up the environment
git clone https://github.com/bytedance-seed/BAGEL.git
cd BAGEL
conda create -n bagel python=3.10 -y
conda activate bagel
pip install -r requirements.txt
pip install flash_attn==2.5.8 --no-build-isolation
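Before downloading the checkpoint, it may help to sanity-check the installation; this is a minimal, optional sketch (exact version numbers depend on your environment):
import torch
import flash_attn

# If both imports succeed, PyTorch and FlashAttention are installed.
print(f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
print(f"flash-attn {flash_attn.__version__}")  # expected: 2.5.8, as pinned above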
2️⃣ Download the pretrained checkpoint
from huggingface_hub import snapshot_download
save_dir = "models/BAGEL-7B-MoT-alchemist"
repo_id = "yandex/BAGEL-7B-MoT-alchemist"
cache_dir = save_dir + "/cache"
snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)
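Once the download finishes, a quick listing can confirm that the files used in the next step (llm_config.json, vit_config.json, ae.safetensors, ema.safetensors) are in place; a minimal, optional check:
import os

expected = ["llm_config.json", "vit_config.json", "ae.safetensors", "ema.safetensors"]
present = set(os.listdir(save_dir))
for name in expected:
    print(f"{name}: {'OK' if name in present else 'MISSING'}")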
3️⃣ Load BAGEL-Alchemist. Note that it was trained on images with a maximum side of 1408 px!
import os
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)
import requests
from io import BytesIO
from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
from safetensors.torch import load_file
model_path = "/path/to/BAGEL-7B-MoT-alchemist/weights"  # path to the checkpoint downloaded in step 2
# Prepare the LLM config
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"
# Prepare the ViT config
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1
# Load the VAE
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
# Prepare the Bagel config
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=88,  # max_latent_size is 88 for BAGEL-alchemist!
)
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
# Prepare the tokenizer
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
# Prepare the image transforms
vae_transform = ImageTransform(1408, 512, 16) # maximum image side is 1408 for BAGEL-alchemist!
vit_transform = ImageTransform(980, 224, 14)
max_mem_per_gpu = "40GiB"  # Modify it according to your GPU setting. On an A100, 80 GiB is sufficient to load on a single GPU.
device_map = infer_auto_device_map(
    model,
    max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)
print(device_map)
# Keep these modules on the same device to avoid cross-device tensor operations.
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed',
]
if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
# Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
    offload_folder="/tmp/offload",
)
model = model.eval()
print('Model loaded')
4️⃣ Run inference, e.g. text-to-image (T2I) generation
from inferencer import InterleaveInferencer
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)
inference_hyper = dict(
    cfg_text_scale=6.0,       # strength of text classifier-free guidance
    cfg_img_scale=1.0,        # image guidance; 1.0 disables it (pure T2I)
    cfg_interval=[0.0, 1.0],  # fraction of denoising steps where CFG is applied
    timestep_shift=3.0,       # shifts the denoising timestep distribution
    num_timesteps=50,         # total number of denoising steps
    cfg_renorm_min=0.0,
    cfg_renorm_type="global",
)
prompt = "A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
print(prompt)
print('-' * 10)
output_dict = inferencer(text=prompt, **inference_hyper)
display(output_dict['image'])  # displays inline in a Jupyter notebook
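Outside a notebook, where display is unavailable, the returned image (a PIL image, as the display call above suggests) can be written to disk instead; the file name here is arbitrary:
# Save the generated image to disk instead of displaying it inline.
output_dict['image'].save("alchemist_sample.png")
print("Saved alchemist_sample.png")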