from core import runner
import torch
from torch import tensor
from PIL import Image
import numpy as np
import torch.nn.functional as F
import gradio as gr
#-----------------------global definitions------------------------#
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'{device.type=}')
description = '''
Choose an example below; OR

Upload by yourself:

1. Upload any test image (query) with any target object you wish to segment
2. Upload another image (support) with the target object or a variation of it
3. Upload a binary mask that segments the target object in the support image
'''
# qimg, simg, smask
example_episodes = [
    ['./imgs/549870_35.jpg', './imgs/457070_00.jpg', './imgs/457070_00.png'],
    ['./imgs/ISIC_0000372.jpg', './imgs/ISIC_0013176.jpg', './imgs/ISIC_0013176_segmentation.png'],
    ['./imgs/d_r_450_.jpg', './imgs/d_r_465_.jpg', './imgs/d_r_465_.bmp'],
    ['./imgs/CHNCXR_0282_0.png', './imgs/CHNCXR_0324_0.png', './imgs/CHNCXR_0324_0_mask.png'],
    ['./imgs/1.jpg', './imgs/5.jpg', './imgs/5.png'],
    ['./imgs/cake1.png', './imgs/cake2.png', './imgs/cake2_mask.png']
]
blank_img = './imgs/blank.png'
gr_img = lambda name: gr.Image(label=name, sources=['upload', 'webcam'], type="pil")
inputs = [gr_img('Query Img'), gr_img('Support Img'), gr_img('Support Mask'), gr.Checkbox(label='re-adapt')]
if device.type == 'cpu':
    inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))
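# Build the feature maker once at startup; a dummy dataset with a single class id
# stands in for the real dataset argument that runner.makeFeatureMaker expects.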
def prepare_feat_maker():
    config = runner.makeConfig()
    class DummyDataset:
        class_ids = [0]
    fake_feat_maker = runner.makeFeatureMaker(DummyDataset(), config, device=device)
    return fake_feat_maker
feat_maker = prepare_feat_maker()
has_fit = False
#-----------------------------------------------------------------#
def reset_layers():
    global feat_maker
    feat_maker = prepare_feat_maker()
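# Pack the three PIL inputs into a 1-way 1-shot episode dict for the runner:
# images go through the FSSDataset transform, the support mask is resized to the
# support image, and batch / k-shot dimensions are added.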
def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    from data.dataset import FSSDataset
    FSSDataset.initialize(img_size=400, datapath='')
    q_img_tensor = FSSDataset.transform(q_img_pil)
    s_img_tensor = FSSDataset.transform(s_img_pil)
    s_mask_tensor = torch.tensor(np.array(s_mask_pil.convert('L')))
    s_mask_tensor = F.interpolate(s_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
                                  s_img_tensor.size()[-2:], mode='nearest').squeeze()
    add_batch_dim = lambda t: t.unsqueeze(0)
    add_kshot_dim = lambda t: t.unsqueeze(1)
    fake_batch = {'query_img': add_batch_dim(q_img_tensor),
                  'support_imgs': add_kshot_dim(add_batch_dim(s_img_tensor)),
                  'support_masks': add_kshot_dim(add_batch_dim(s_mask_tensor)),
                  'class_id': tensor([0])}
    return fake_batch
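# Min-max normalisation to [0, 1], used for both display images.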
norm = lambda t: (t - t.min()) / (t.max() - t.min())
def overlay(img, mask):
    # img: (H, W, 3) float array; mask: (H, W) float array
    return norm(img) * 0.5 + mask[:, :, np.newaxis] * 0.5
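# Run a single-sample evaluation with the current feat_maker and return two displayable
# images: the normalised logit map and the predicted mask overlaid on the query image.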
def from_model(q_img, s_img, s_mask):
    batch = prepare_batch(q_img, s_img, s_mask)
    sseval = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = sseval.forward()
    global has_fit
    has_fit = True
    # return the logit map normalised to [0, 1] and the mask-overlaid query image (also in [0, 1])
    return norm(pred_logits[0].numpy()), overlay(batch['query_img'][0].permute(1, 2, 0).numpy(), pred_mask[0].numpy())
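# Gradio callback. Three cases: the example-cache run at startup (checkbox inputs are None),
# an explicit 're-adapt' run that refits the layers, and a plain prediction that reuses
# layers fitted in a previous run.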
def predict(q, s, m, re_adapt, confirmed=None):  # the confirm checkbox only exists when running on CPU
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')
    is_cache_run = re_adapt is None and confirmed is None
    is_example = any(np.array_equal(np.array(m), np.array(Image.open(e[2]))) for e in example_episodes)  # e[2] is the support mask path
    print(f'{is_example=}')
    if is_cache_run:
        reset_layers()
        pred = from_model(q, s, m)
        msg = 'Results ready.'
        return msg, *pred
    elif re_adapt:
        if confirmed or device.type != 'cpu':  # no confirmation needed on GPU
            reset_layers()
            pred = from_model(q, s, m)
            msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
            return msg, *pred
        else:
            msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on the Hugging Face Space.\nSelect 'Confirm CPU run' to start."
            return msg, blank_img, blank_img
    else:
        if is_example:
            msg = "Cached results for this example have already been shown.\nTo view them again, click the example again.\nTo run adaptation from scratch again, select 're-adapt'."
            return msg, blank_img, blank_img
        else:
            if has_fit:
                pred = from_model(q, s, m)
                msg = "Results predicted using layers fitted in the previous run.\nIf you wish to re-adapt, select 're-adapt'."
                return msg, *pred
            else:
                msg = "This is the first time you are predicting your own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
                return msg, blank_img, blank_img
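# Assemble the UI: one status textbox and two image outputs.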
gradio_app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=[gr.Textbox(label="Status"),
             gr.Image(label="Coarse Query Prediction"),
             gr.Image(label="Mask Prediction")],
    description=description,
    examples=example_episodes,
    title="abcdfss",
)
gradio_app.launch()