{ "cells": [ { "cell_type": "code", "execution_count": 383, "metadata": {}, "outputs": [], "source": [ "# Loading all the important libraries\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from matplotlib import pyplot as plt\n", "from PIL import Image\n", "import torchvision\n", "# concatenate_datasets is needed by the dataset-loading cell further down.\n", "from datasets import concatenate_datasets, load_dataset\n", "from torchvision import transforms\n", "from diffusers import StableDiffusionPipeline\n", "from peft import LoraConfig, get_peft_model\n", "from torch.utils.data import DataLoader\n", "from tqdm.auto import tqdm\n", "import os\n", "from torch.amp import autocast, GradScaler\n", "import warnings" ] }, { "cell_type": "code", "execution_count": 384, "metadata": {}, "outputs": [], "source": [ "# Suppress every Python warning for the rest of the session (global filter).\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 385, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Repo card metadata block was not found. Setting CardData to empty.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Combined dataset size: 24554\n" ] } ], "source": [ "# Load the Rapidata/Animals-10 dataset and the Smithsonian butterflies subset (train splits).\n", "butterflies = load_dataset(\"huggan/smithsonian_butterflies_subset\", split=\"train\")\n", "animals = load_dataset(\"Rapidata/Animals-10\", split=\"train\")\n", "\n", "# Combine datasets: animals first, butterflies appended after them.\n", "dataset = concatenate_datasets([animals, butterflies])\n", "print(\"Combined dataset size:\", len(dataset))" ] }, { "cell_type": "code", "execution_count": 386, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 138, "referenced_widgets": [ "a8b35c686e9042d5ba14cd29462d2baf", "499b384d89104156915fb644cae54692", "ab80f5f0927d40d98892cfd57d503266", "84446392113b4a7abeaeea80529d5d99", "57ed90c0acb54ec9884233b4ae4ed81c", "84448187f23f4d418010fdc4a5aa9678", "dfa3ce719ac040c0a9fd0b161a06d20a", "84f312d7b02c4353b99b73ca99d5662b", "3413c74883d740909dea596481092d50", 
"aee8b5c034bf42e88776b5311a395b1a", "56dc56a5bc3a49d6a7dec255d7c1fc68" ] }, "id": "Fp3D-fRmnlDn", "outputId": "454c42d4-b396-4b6c-e0fe-1e8ce307786b" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "68db9e1e82b1492292b73783dfc5f06d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/6 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Visualize training loss\n", "plt.figure(figsize=(8,5))\n", "plt.plot(train_losses, marker='o', color='blue')\n", "plt.title(\"Stable Diffusion LoRA Fine-tuning Loss Curve\")\n", "plt.xlabel(\"Batch\")\n", "plt.ylabel(\"Mean MSE Loss\")\n", "plt.grid(True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 397, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pipeline saved successfully at models/stable_diffusion_training\n" ] } ], "source": [ "# Attach trained components back to pipeline\n", "unet = unet.merge_and_unload() # Converts back to standard UNet\n", "\n", "# Re-attach to pipeline\n", "pipe.unet = unet\n", "pipe.vae = vae\n", "pipe.text_encoder = text_encoder\n", "pipe.tokenizer = tokenizer\n", "\n", "\n", "# Save the full fine-tuned pipeline\n", "pipe.save_pretrained(out_dir)\n", "pipe.tokenizer.save_pretrained(out_dir)\n", "\n", "print(f\"Pipeline saved successfully at {out_dir}\")" ] }, { "cell_type": "code", "execution_count": 402, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "616198a201fa4310a3d232a32e5a8364", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00" ] }, "execution_count": 402, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Generating an initial image from a promt\n", "init_prompt = \"horse on a beach with waves that reaching its feet and the water splashes after striking its feet\"\n", "generator = torch.Generator(device=device).manual_seed(40)\n", 
"init_image = pipe(init_prompt, negative_promt=\"blurry, satureated, deformed\", guidance_scale=9, num_inference_steps=30, generator=generator).images[0]\n", "init_image" ] }, { "cell_type": "code", "execution_count": 403, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be7f70f111ab4da58a97366017397492", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/6 [00:00" ] }, "execution_count": 404, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Progressive image generation using the prompts\n", "prompts = [\n", " \"cowboy sitting straight on the horse on a beach with medium waves that reaches its feet and the water splashes after striking its feet\",\n", " \"cowboy sitting straight on the horse on a beach with medium waves that reaches its feet and the water splashes after striking its feet, mountains in the background\"\n", "]\n", "generator = torch.Generator(device=device).manual_seed(42)\n", "\n", "image = init_image\n", "for i, prompt in enumerate(prompts):\n", " image = pipe_img2img(prompt=prompt, negative_promt=\"blurry, deformed, distorted,saturated\", image=image, strength=0.9, guidance_scale=10.0 ,num_inference_steps=100, generator=generator).images[0]\n", "\n", "image" ] }, { "cell_type": "code", "execution_count": 405, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e693bd7c397e4bed86beb6148e5b7848", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='