|
|
import gradio as gr
|
|
|
import argparse
|
|
|
import os
|
|
|
|
|
|
import pandas as pd
|
|
|
from PIL import Image
|
|
|
import numpy as np
|
|
|
import torch as th
|
|
|
from torchvision import transforms
|
|
|
|
|
|
import diffusers
|
|
|
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, LCMScheduler
|
|
|
import gc
|
|
|
from safetensors import safe_open
|
|
|
|
|
|
from models import SAR2OptUNetv3
|
|
|
from utils import update_args_from_yaml, safe_load
|
|
|
|
|
|
# Preprocessing for the SAR input image (``predict`` converts it to
# single-channel "L" mode first): tensor conversion, resize to 256x256,
# then normalization with mean 0.5 and std 0.5.
transform_sar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(256, 256)),
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

# Name offered in the UI dropdown -> path of the matching safetensors
# checkpoint (relative paths, resolved against the working directory).
AVAILABLE_MODELS = {
    "Sen12:LCM-Model": "models/model.safetensors",
    "Sen12:Org-Model": "models/model_org.safetensors",
}

# Use the first CUDA device when CUDA is available, otherwise the CPU.
device = th.device("cuda:0") if th.cuda.is_available() else th.device("cpu")
|
|
|
|
|
|
def safe_load(model_path):
    """Read every tensor of a safetensors checkpoint into a state dict.

    NOTE: this definition shadows the ``safe_load`` imported from ``utils``
    at the top of the file.

    Args:
        model_path: Path of the checkpoint (``str`` or ``os.PathLike``). It
            must contain the substring ``"safetensors"``.

    Returns:
        A ``dict`` mapping each key stored in the file to its tensor. The
        file is opened with ``framework="pt"`` and ``device="cpu"``.

    Raises:
        ValueError: If ``model_path`` does not contain ``"safetensors"``.
    """
    model_path = os.fspath(model_path)
    # Explicit check rather than ``assert`` so that the validation is still
    # performed when Python runs with ``-O``.
    if "safetensors" not in model_path:
        raise ValueError(
            f"expected a safetensors checkpoint, got {model_path!r}"
        )
    with safe_open(model_path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}
|
|
|
|
|
|
# Architecture of the conditional UNet. ``predict`` feeds it a 3-channel
# latent concatenated with the 1-channel SAR condition, hence 4 input
# channels and 3 output channels.
_unet_config = dict(
    sample_size=256,
    in_channels=4,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    # Attention is used only in the fifth down block ...
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    # ... and in the second up block.
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)
unet_model = SAR2OptUNetv3(**_unet_config)

# NOTE: only the architecture has been built here; checkpoint weights are
# loaded later, inside ``predict``.
print('load unet safetensos done!')

lcm_scheduler = LCMScheduler(num_train_timesteps=1000)

# Move the network to the selected device and switch it to inference mode.
unet_model.to(device)
unet_model.eval()

# Not referenced anywhere else in the visible part of this file.
model_kwargs = dict()
|
|
|
|
|
|
|
|
|
def predict(condition, nums_step, model_name):
    """Translate a SAR image into an optical image with the selected model.

    Args:
        condition: PIL image holding the SAR input. It is converted to
            single-channel ("L") mode and run through ``transform_sar``.
        nums_step: Number of scheduler inference steps; cast to ``int``.
        model_name: Key of ``AVAILABLE_MODELS`` selecting the checkpoint.

    Returns:
        The generated image as a ``PIL.Image.Image`` built from a
        256x256x3 ``uint8`` array.

    Raises:
        ValueError: If ``condition`` is ``None`` or ``nums_step`` is below 1.
        KeyError: If ``model_name`` is not a key of ``AVAILABLE_MODELS``.
    """
    # Validate the request before doing any expensive work.
    if condition is None:
        raise ValueError("Please provide a SAR image to translate.")
    # The UI slider may hand over a float; the step count must be an integer.
    nums_step = int(nums_step)
    if nums_step < 1:
        raise ValueError(f"nums_step must be at least 1, got {nums_step}")

    _ensure_checkpoint_loaded(model_name)

    with th.no_grad():
        lcm_scheduler.set_timesteps(nums_step, device=device)
        # Start from pure Gaussian noise with the shape of the output image.
        pred_latent = th.randn(size=[1, 3, 256, 256], device=device)
        # Grayscale SAR condition with a leading batch dimension.
        sar = transform_sar(condition.convert("L")).unsqueeze(0).to(device)

        for timestep in lcm_scheduler.timesteps:
            # The network sees the current latent and the SAR condition
            # stacked along the channel axis.
            model_pred = unet_model(th.cat((pred_latent, sar), dim=1), timestep)
            pred_latent, denoised = lcm_scheduler.step(
                model_output=model_pred,
                timestep=timestep,
                sample=pred_latent,
                return_dict=False,
            )

    # (x + 1) * 127.5 sends -1 to 0 and 1 to 255; values outside are clamped.
    pixels = ((denoised.cpu() + 1) * 127.5).clamp(0, 255).to(th.uint8)
    # Channels-first batch -> single channels-last image for PIL.
    pixels = pixels.permute(0, 2, 3, 1).squeeze(0).contiguous()
    return Image.fromarray(pixels.numpy())


# Path of the checkpoint whose weights are currently loaded into
# ``unet_model``; ``None`` until the first successful load.
_loaded_checkpoint = None


def _ensure_checkpoint_loaded(model_name):
    """Load the weights for ``model_name`` unless they are already active."""
    global _loaded_checkpoint
    checkpoint_path = AVAILABLE_MODELS[model_name]
    if checkpoint_path != _loaded_checkpoint:
        unet_model.load_state_dict(safe_load(checkpoint_path), strict=True)
        # Recorded only after a successful load so a failure is retried.
        _loaded_checkpoint = checkpoint_path
    unet_model.eval().to(device)
|
|
|
|
|
|
|
|
|
# Directory containing this script; the example images are looked up here.
_EXAMPLE_DIR = os.path.dirname(__file__)

# Gradio UI: a SAR image, a step-count slider and a model selector go in,
# the generated optical image comes out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil"),
        # step=1 keeps the slider on whole numbers, since the value is used
        # as the number of inference steps.
        gr.Slider(1, 1000, step=1),
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS),
            value=next(iter(AVAILABLE_MODELS)),
            label="Choose the Model",
        ),
    ],
    outputs=gr.Image(type="pil"),
    # Each example row is [image path, number of steps, model name].
    examples=[
        [os.path.join(_EXAMPLE_DIR, "sar_1.png"), 8, "Sen12:LCM-Model"],
        [os.path.join(_EXAMPLE_DIR, "sar_2.png"), 16, "Sen12:LCM-Model"],
        [os.path.join(_EXAMPLE_DIR, "sar_3.png"), 500, "Sen12:Org-Model"],
        [os.path.join(_EXAMPLE_DIR, "sar_4.png"), 1000, "Sen12:Org-Model"],
    ],
    title="SAR to Optical Image🚀",
    description="""
# 🎯 Instruction
This is a project that converts SAR images into optical images, based on conditional diffusion.

Input a SAR image, and its corresponding optical image will be obtained.

## 📢 Inputs
- `condition`: the SAR image that you want to translate.
- `nums_step`: the number of iteration steps used during inference.
- `Choose the Model`: the checkpoint used to generate the optical image.

## 🎉 Outputs
- The corresponding optical image.

**Paper** : [Guided Diffusion for Image Generation](https://arxiv.org/abs/2105.05233)

**Github** : https://github.com/Coordi777/Conditional_SAR2OPT
""",
)

if __name__ == "__main__":
    # Serve the demo on a fixed port when the file is run as a script.
    demo.launch(server_port=16006)
|
|
|
|