"""Gradio demo for one-shot segmentation: upload a query image plus a support
image/mask pair (or pick a cached example) and view the predicted mask."""
from core import runner

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import tensor

import gradio as gr

# ----------------------- global definitions ------------------------ #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'{device.type=}')

# Shown above the interface; '\\' renders as a literal backslash (the original
# bare '\' before a newline would act as a line-continuation inside the string).
description = '''

Choose an example below; OR
\\ Upload by yourself:
\\ 1. Upload any test image (query) with any target object you wish to segment
\\ 2. Upload another image (support) with the target object or a variation of it
\\ 3. Upload a binary mask that segments the target object in the support image
\\

'''

# Each episode: [query image, support image, support mask].
example_episodes = [
    ['./imgs/549870_35.jpg', './imgs/457070_00.jpg', './imgs/457070_00.png'],
    ['./imgs/ISIC_0000372.jpg', './imgs/ISIC_0013176.jpg', './imgs/ISIC_0013176_segmentation.png'],
    ['./imgs/d_r_450_.jpg', './imgs/d_r_465_.jpg', './imgs/d_r_465_.bmp'],
    ['./imgs/CHNCXR_0282_0.png', './imgs/CHNCXR_0324_0.png', './imgs/CHNCXR_0324_0_mask.png'],
    ['./imgs/1.jpg', './imgs/5.jpg', './imgs/5.png'],
    ['./imgs/cake1.png', './imgs/cake2.png', './imgs/cake2_mask.png'],
]
blank_img = './imgs/blank.png'

gr_img = lambda name: gr.Image(label=name, sources=['upload', 'webcam'], type="pil")
inputs = [
    gr_img('Query Img'),
    gr_img('Support Img'),
    gr_img('Support Mask'),
    gr.Checkbox(label='re-adapt'),
]
# The confirm checkbox only exists on CPU; predict() therefore defaults
# `confirmed` so it can be called with 4 args on GPU without a TypeError.
if device.type == 'cpu':
    inputs.append(gr.Checkbox(label='Confirm CPU run (CHOOSE ONLY WHEN REQUESTED)'))


def prepare_feat_maker():
    """Build a fresh feature maker (un-fitted adaptation layers)."""
    config = runner.makeConfig()

    class DummyDataset:
        # Minimal stand-in; presumably only class_ids is read by
        # makeFeatureMaker — TODO confirm against core.runner.
        class_ids = [0]

    return runner.makeFeatureMaker(DummyDataset(), config, device=device)


feat_maker = prepare_feat_maker()
has_fit = False  # True once from_model() has fitted the layers at least once
# ------------------------------------------------------------------ #


def reset_layers():
    """Discard any fitted state by rebuilding the global feature maker."""
    global feat_maker
    feat_maker = prepare_feat_maker()


def prepare_batch(q_img_pil, s_img_pil, s_mask_pil):
    """Pack PIL query/support images and the support mask into a 1-shot batch.

    Returns a dict with keys query_img (1,C,H,W), support_imgs (1,1,C,H,W),
    support_masks (1,1,H,W) and class_id.
    """
    from data.dataset import FSSDataset
    FSSDataset.initialize(img_size=400, datapath='')
    q_img_tensor = FSSDataset.transform(q_img_pil)
    s_img_tensor = FSSDataset.transform(s_img_pil)
    s_mask_tensor = torch.tensor(np.array(s_mask_pil.convert('L')))
    # Nearest-neighbour resize so the mask stays hard-labelled.
    s_mask_tensor = F.interpolate(
        s_mask_tensor.unsqueeze(0).unsqueeze(0).float(),
        s_img_tensor.size()[-2:],
        mode='nearest',
    ).squeeze()
    add_batch_dim = lambda t: t.unsqueeze(0)
    add_kshot_dim = lambda t: t.unsqueeze(1)
    return {
        'query_img': add_batch_dim(q_img_tensor),
        'support_imgs': add_kshot_dim(add_batch_dim(s_img_tensor)),
        'support_masks': add_kshot_dim(add_batch_dim(s_mask_tensor)),
        'class_id': tensor([0]),
    }


def norm(t):
    """Min-max normalise to [0, 1]; a constant input maps to all zeros
    (the original lambda divided by zero in that case)."""
    span = t.max() - t.min()
    if span == 0:
        return t * 0
    return (t - t.min()) / span


def overlay(img, mask):
    """Blend mask onto image. img: (h, w, 3) float; mask: (h, w) float."""
    return norm(img) * 0.5 + mask[:, :, np.newaxis] * 0.5


def from_model(q_img, s_img, s_mask):
    """Run one episode through the model.

    Returns (normalised logit map, mask-overlaid query image), both as
    numpy arrays in [0, 1]. Marks the global has_fit flag.
    """
    batch = prepare_batch(q_img, s_img, s_mask)
    sseval = runner.SingleSampleEval(batch, feat_maker)
    pred_logits, pred_mask = sseval.forward()
    global has_fit
    has_fit = True
    # logit mask in range from -1 to 1, and mask-overlaid query image 0 to 1
    return (
        norm(pred_logits[0].numpy()),
        overlay(batch['query_img'][0].permute(1, 2, 0).numpy(), pred_mask[0].numpy()),
    )


def predict(q, s, m, re_adapt=None, confirmed=None):
    """Gradio callback: returns (status message, coarse prediction, mask overlay).

    `confirmed` defaults to None because the CPU-confirm checkbox is only
    added to `inputs` on CPU — on GPU Gradio passes just 4 arguments, which
    crashed the original 5-parameter signature.
    """
    print(f'predict with {re_adapt=}, {confirmed=}')
    print(f'{type(q)=}')
    # Example-cache pre-runs invoke the fn with both flags None.
    is_cache_run = re_adapt is None and confirmed is None
    if confirmed is None and not is_cache_run:
        # No confirm checkbox exists on GPU; treat the run as confirmed there.
        confirmed = device.type != 'cpu'
    # e[2] points to the support mask of each example episode.
    is_example = any(
        np.array_equal(np.array(m), np.array(Image.open(e[2])))
        for e in example_episodes
    )
    print(f'{is_example=}')

    if is_cache_run:
        reset_layers()
        pred = from_model(q, s, m)
        return 'Results ready.', *pred

    if re_adapt:
        if confirmed:
            reset_layers()
            pred = from_model(q, s, m)
            msg = "Results ready.\nRemember to untick 're-adapt' if you wish to predict more images with the same parameters."
            return msg, *pred
        msg = "You chose to re-adapt but are on CPU.\nThis may take 1 minute on your local machine or 4 minutes on huggingface space.\nSelect 'Confirm CPU run' to start."
        return msg, blank_img, blank_img

    if is_example:
        msg = "Cached results for example have been shown previously already.\nTo view it again, click the example again.\nTo run adaption again from scratch, select 're-adapt'."
        return msg, blank_img, blank_img

    if has_fit:
        pred = from_model(q, s, m)
        msg = "Results predicted based on layers fitted from previous run.\nIf you wish to re-adapt, select 're-adapt'."
        return msg, *pred

    msg = "This is the first time you predict own images.\nThe attached layers need to be fitted.\nPlease select 're-adapt'."
    return msg, blank_img, blank_img


gradio_app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Coarse Query Prediction"),
        gr.Image(label="Mask Prediction"),
    ],
    description=description,
    examples=example_episodes,
    title="abcdfss",
)

gradio_app.launch()