Commit 1235e6e
Parent(s): 173552f

Update training 2

Files changed: app.py (+12 -4), train_dreambooth.py (+68 -11)
app.py CHANGED

@@ -30,7 +30,7 @@ maximum_concepts = 3
 
 #Pre download the files
 model_v1 = snapshot_download(repo_id="multimodalart/sd-fine-tunable")
-
+model_v2 = snapshot_download(repo_id="stabilityai/stable-diffusion-2")
 model_v2_512 = snapshot_download(repo_id="stabilityai/stable-diffusion-2-base")
 safety_checker = snapshot_download(repo_id="multimodalart/sd-sc")
 
@@ -171,6 +171,10 @@ def train(*inputs):
         Training_Steps=1400
 
     stptxt = int((Training_Steps*Train_text_encoder_for)/100)
+    #gradient_checkpointing = False if which_model == "v1-5" else True
+    gradient_checkpointing=False
+    resolution = 512 if which_model != "v2-768" else 768
+    cache_latents = True if which_model != "v1-5" else False
     if (type_of_thing == "object" or type_of_thing == "style" or (type_of_thing == "person" and not experimental_face_improvement)):
         args_general = argparse.Namespace(
             image_captions_filename = True,
@@ -183,7 +187,7 @@ def train(*inputs):
            output_dir="output_model",
            instance_prompt="",
            seed=42,
-           resolution=512,
+           resolution=resolution,
            mixed_precision="fp16",
            train_batch_size=1,
            gradient_accumulation_steps=1,
@@ -192,6 +196,8 @@ def train(*inputs):
            lr_scheduler="polynomial",
            lr_warmup_steps = 0,
            max_train_steps=Training_Steps,
+           gradient_checkpointing=gradient_checkpointing,
+           cache_latents=cache_latents,
        )
        print("Starting single training...")
        lock_file = open("intraining.lock", "w")
@@ -211,7 +217,7 @@ def train(*inputs):
            prior_loss_weight=1.0,
            instance_prompt="",
            seed=42,
-           resolution=512,
+           resolution=resolution,
            mixed_precision="fp16",
            train_batch_size=1,
            gradient_accumulation_steps=1,
@@ -220,7 +226,9 @@ def train(*inputs):
            lr_scheduler="polynomial",
            lr_warmup_steps = 0,
            max_train_steps=Training_Steps,
-           num_class_images=200,
+           num_class_images=200,
+           gradient_checkpointing=gradient_checkpointing,
+           cache_latents=cache_latents,
        )
        print("Starting multi-training...")
        lock_file = open("intraining.lock", "w")
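Taken together, the app.py side of the commit derives three training options from the selected base model before building either argparse.Namespace. The sketch below restates that selection logic as a standalone function so it can be sanity-checked in isolation; it is a minimal sketch, and both the training_options helper name and the full set of which_model identifiers (only "v1-5" and "v2-768" appear verbatim in the diff) are assumptions for illustration.

# Minimal sketch of the per-model option selection added in app.py.
def training_options(which_model: str) -> dict:
    return {
        # v2-768 checkpoints train at 768px, everything else at 512px
        "resolution": 768 if which_model == "v2-768" else 512,
        # VAE latents are pre-computed for all non-v1-5 models
        "cache_latents": which_model != "v1-5",
        # pinned off in this commit (the v2-only variant is left commented out)
        "gradient_checkpointing": False,
    }

print(training_options("v2-768"))
# {'resolution': 768, 'cache_latents': True, 'gradient_checkpointing': False}

Pinning gradient_checkpointing off keeps the speed benefit of not recomputing activations, while cache_latents recovers memory and time on the v2 models by skipping repeated VAE encodes.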
train_dreambooth.py CHANGED

@@ -235,6 +235,13 @@ def parse_args():
         help="Train only the unet",
     )
 
+    parser.add_argument(
+        "--cache_latents",
+        action="store_true",
+        default=False,
+        help="Train only the unet",
+    )
+
     parser.add_argument(
         "--Session_dir",
         type=str,
@@ -382,6 +389,16 @@ class PromptDataset(Dataset):
         example["index"] = index
         return example
 
+class LatentsDataset(Dataset):
+    def __init__(self, latents_cache, text_encoder_cache):
+        self.latents_cache = latents_cache
+        self.text_encoder_cache = text_encoder_cache
+
+    def __len__(self):
+        return len(self.latents_cache)
+
+    def __getitem__(self, index):
+        return self.latents_cache[index], self.text_encoder_cache[index]
 
 def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
     if token is None:
@@ -631,6 +648,28 @@ def run_training(args_imported):
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
+
+    if args.cache_latents:
+        latents_cache = []
+        text_encoder_cache = []
+        for batch in tqdm(train_dataloader, desc="Caching latents"):
+            with torch.no_grad():
+                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device, non_blocking=True, dtype=weight_dtype)
+                batch["input_ids"] = batch["input_ids"].to(accelerator.device, non_blocking=True)
+                latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                if args.train_text_encoder:
+                    text_encoder_cache.append(batch["input_ids"])
+                else:
+                    text_encoder_cache.append(text_encoder(batch["input_ids"])[0])
+        train_dataset = LatentsDataset(latents_cache, text_encoder_cache)
+        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, collate_fn=lambda x: x, shuffle=True)
+
+        del vae
+        if not args.train_text_encoder:
+            del text_encoder
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
@@ -669,8 +708,12 @@ def run_training(args_imported):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                with torch.no_grad():
+                    if args.cache_latents:
+                        latents = batch[0][0]
+                    else:
+                        latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+                    latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -684,26 +727,40 @@ def run_training(args_imported):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                if(args.cache_latents):
+                    if args.train_text_encoder:
+                        encoder_hidden_states = text_encoder(batch[0][1])[0]
+                    else:
+                        encoder_hidden_states = batch[0][1]
+                else:
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                # Get the target for loss depending on the prediction type
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
                 if args.with_prior_preservation:
-                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-                    noise, noise_prior = torch.chunk(noise, 2, dim=0)
+                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+                    target, target_prior = torch.chunk(target, 2, dim=0)
 
                     # Compute instance loss
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
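The heart of the train_dreambooth.py change is the --cache_latents path: when the VAE (and, if it is not being trained, the text encoder) is frozen, its outputs for every training image can be computed once up front, and the per-step loop then trains the UNet against the cached tensors. Below is a torch-only sketch of that pattern; encode_image and encode_text are stand-in modules for the frozen vae.encode(...) and text_encoder(...) calls, not the real models.

import torch
from torch.utils.data import DataLoader, Dataset

class LatentsDataset(Dataset):
    """Mirrors the LatentsDataset added in the diff: pairs of cached tensors."""
    def __init__(self, latents_cache, text_encoder_cache):
        self.latents_cache = latents_cache
        self.text_encoder_cache = text_encoder_cache

    def __len__(self):
        return len(self.latents_cache)

    def __getitem__(self, index):
        return self.latents_cache[index], self.text_encoder_cache[index]

# Stand-ins for the frozen VAE / text encoder (illustrative, not the real models).
encode_image = torch.nn.Conv2d(3, 4, kernel_size=8, stride=8)
encode_text = torch.nn.Embedding(1000, 768)

raw_batches = [
    {"pixel_values": torch.randn(1, 3, 64, 64), "input_ids": torch.randint(0, 1000, (1, 8))}
    for _ in range(4)
]

# Cache once, outside the autograd graph, as in the diff's pre-training loop.
latents_cache, text_cache = [], []
with torch.no_grad():
    for batch in raw_batches:
        latents_cache.append(encode_image(batch["pixel_values"]))
        text_cache.append(encode_text(batch["input_ids"]))

loader = DataLoader(LatentsDataset(latents_cache, text_cache),
                    batch_size=1, collate_fn=lambda x: x, shuffle=True)

for batch in loader:
    latents, encoder_hidden_states = batch[0]
    print(latents.shape, encoder_hidden_states.shape)

The collate_fn=lambda x: x keeps each DataLoader batch as a plain one-element list of (latents, embeddings) tuples, which is exactly why the commit's training loop indexes batch[0][0] and batch[0][1] when args.cache_latents is set.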