import torch
import types
import timm
import requests
import random
import yaml
import gradio as gr
from PIL import Image
from timm import create_model
from torchvision import transforms
from timm.data import resolve_data_config
from modelguidedattacks.guides.unguided import Unguided
from timm.data.transforms_factory import create_transform
from modelguidedattacks.cls_models.registry import TimmPretrainModelWrapper

# Download human-readable labels for ImageNet.
IMAGENET_LABELS_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
LABELS = requests.get(IMAGENET_LABELS_URL).text.strip().split("\n")
SORTED_LABELS = sorted(LABELS.copy(), key=lambda s: s.lower())
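
# LABELS[i] is the human-readable WordNet lemma string for ImageNet class index i;
# SORTED_LABELS is the same list sorted case-insensitively for the label dropdown below.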

def get_timm_model(name):
    """Retrieve a model from the timm library by name, with pretrained weights loaded."""
    model = create_model(name, pretrained=True)
    transform = create_transform(**resolve_data_config({}, model=model))
    model = model.eval()
    return model, transform
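
# Illustrative usage of get_timm_model (not part of the demo flow; "resnet50" is just an
# example model name, and downloading its weights requires network access):
#   model, transform = get_timm_model("resnet50")
#   # `transform` is the timm-recommended preprocessing pipeline for that checkpoint.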

def create_attacker(model, transform, iterations):
    """Instantiate an unguided QuadAttack attacker around a wrapped timm model."""
    # config_dict = {"cvx_proj_margin": 0.2,
    #                "opt_warmup_its": 5}
    with open("base_config.yaml") as f:
        config_dict = yaml.safe_load(f)
    config = types.SimpleNamespace(**config_dict)
    attacker = Unguided(TimmPretrainModelWrapper(model, transform, "", "", ""), config,
                        iterations=iterations, lr=0.002, topk_loss_coef_upper=10)
    return attacker
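
# Note: base_config.yaml is assumed to sit next to this script and to supply the attack
# hyperparameters Unguided expects (e.g. the cvx_proj_margin / opt_warmup_its values hinted
# at in the commented-out dict above); the exact schema is defined by the modelguidedattacks package.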

def predict_topk_accuracies(img, k, iters, model_name, desired_labels, button=None,
                            progress=gr.Progress(track_tqdm=True)):
    """Predict the top-K results for the clean image and for the QuadAttack-perturbed image."""
    label_inds = list(range(1000))  # all ImageNet class indices
    # Convert the user's desired labels to class indices.
    desired_inds = [LABELS.index(name) for name in desired_labels]
    # Remove the selected indices before randomly sampling the rest.
    for ind in desired_inds:
        label_inds.remove(ind)
    # Pad the user's selections with random classes up to the top K.
    desired_inds = desired_inds + random.sample(label_inds, k - len(desired_inds))
    tensorized_desired_inds = torch.tensor(desired_inds).unsqueeze(0)  # [B, K]
    model, transform = get_timm_model(model_name)
    # Split the timm transform: the final op normalizes a tensor, while everything
    # before it resizes/crops the PIL image and converts it to a tensor.
    normalization = transforms.Compose([
        transform.transforms[-1]  # normalization only
    ])
    preprocess = transforms.Compose(
        transform.transforms[:-1]  # resize/crop + conversion to a PyTorch tensor
    )
    attacker = create_attacker(model, normalization, iters)
    img = img.convert('RGB')
    orig_img = img.copy()
    orig_img = preprocess(orig_img)    # un-normalized tensor, C H W
    orig_img = orig_img.unsqueeze(0)   # add batch dim: B C H W
    img = transform(img).unsqueeze(0)  # fully preprocessed input for the clean forward pass
    with torch.no_grad():
        outputs = model(img)
    # The attack optimizes a perturbation, so it runs with gradients enabled.
    attack_outputs, attack_img = attacker(orig_img, tensorized_desired_inds, None)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
    attacker_probs = torch.nn.functional.softmax(attack_outputs[0], dim=0)
    values, indices = torch.topk(probabilities, k)
    attack_vals, attack_inds = torch.topk(attacker_probs, k)
    attack_img_out = orig_img + attack_img  # B C H W
    # Convert the PyTorch tensors to NumPy arrays.
    attack_img_out = attack_img_out.squeeze(0)                 # C H W
    attack_img_out = attack_img_out.permute(1, 2, 0).numpy()   # H W C
    orig_img = orig_img.squeeze(0)
    orig_img = orig_img.permute(1, 2, 0).numpy()
    attack_img = attack_img.squeeze(0)
    attack_img = attack_img.permute(1, 2, 0).numpy()
    # Convert the NumPy arrays to PIL images.
    attack_img_out = Image.fromarray((attack_img_out * 255).astype('uint8'))
    orig_img = Image.fromarray((orig_img * 255).astype('uint8'))
    attack_img = Image.fromarray((attack_img * 255).astype('uint8'))
    return (orig_img, attack_img_out, attack_img,
            {LABELS[i]: v.item() for i, v in zip(indices, values)},
            {LABELS[i]: v.item() for i, v in zip(attack_inds, attack_vals)})
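
# The tuple returned above (clean image, perturbed image, added noise, clean top-K dict,
# attacked top-K dict) lines up with the `outputs` list passed to gr.Interface below.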

def random_fill_classes(desired_labels, k):
    """Trim or randomly pad the user's selected labels so exactly k classes are chosen."""
    label_inds = list(range(1000))  # all ImageNet class indices
    # Convert the user's desired labels to class indices, keeping at most k of them.
    if len(desired_labels) > k:
        desired_labels = desired_labels[:k]
    desired_inds = [LABELS.index(name) for name in desired_labels]
    # Remove the selected indices before randomly sampling the rest.
    for ind in desired_inds:
        label_inds.remove(ind)
    # Pad the user's selections with random classes up to the top K.
    desired_inds = desired_inds + random.sample(label_inds, k - len(desired_inds))
    return [LABELS[ind] for ind in desired_inds]
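
# Illustrative call (the label strings are hypothetical examples): with k=5 and two
# classes pre-selected,
#   random_fill_classes(["goldfish, Carassius_auratus", "tench, Tinca_tinca"], 5)
# keeps those two and appends three randomly sampled class names from LABELS.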

# Gradio UI components.
input_img = gr.Image(type='pil')
top_k_slider = gr.Slider(2, 20, value=10, step=1, label="Top K predictions", info="Choose between 2 and 20")
iteration_slider = gr.Slider(30, 1000, value=60, step=1, label="QuadAttack Iterations",
                             info="Choose how many iterations to optimize using QuadAttack! (Usually <= 60 is enough)")
model_choice_list = gr.Dropdown(
    timm.list_models(), value="vit_base_patch16_224", label="timm model name",
    info="Currently only supporting timm models! See the code for the models used in the paper."
)
desired_labels = gr.Dropdown(
    SORTED_LABELS, max_choices=20, filterable=True, multiselect=True, label="Desired Labels for QuadAttack",
    info="Select the classes you want the attack to force into the top-K output. "
         "Classes are ranked in the order listed and randomly filled up to K if fewer than K are selected."
)
button = gr.Button("Randomly fill Top-K attack classes.")
desc = r'<div align="center">Authors: Thomas Paniagua, Ryan Grainger, Tianfu Wu <p><a href="https://arxiv.org/abs/2312.11510">Paper</a><br><a href="https://github.com/thomaspaniagua/quadattack">Code</a></p> </div>'

with gr.Interface(predict_topk_accuracies,
                  inputs=[input_img,
                          top_k_slider,
                          iteration_slider,
                          model_choice_list,
                          desired_labels,
                          button],
                  outputs=[
                      gr.Image(type='pil', label="Input Image"),
                      gr.Image(type='pil', label="Perturbed Image"),
                      gr.Image(type='pil', label="Added Noise"),
                      gr.Label(label="Original Top K"),
                      gr.Label(label="QuadAttack Top K"),
                      # gr.Image(type='pil', label="Perturbed Image")
                  ],
                  title='QuadAttack!',
                  description=desc,
                  cache_examples=False,
                  allow_flagging="never",
                  thumbnail="quadattack_pipeline.pdf",
                  examples=[["image_examples/RV.jpeg", 5, 30, "vit_base_patch16_224", None, None
                             # ["lemon", "plastic_bag", "hay", "tripod", "bell_cote, bell_cot"]
                             ],
                            # ["image_examples/biker.jpeg", 10, 60, "swinv2_cr_base_224", None, None
                            #  ["hog, pig, grunter, squealer, Sus_scrofa",
                            #   "lesser_panda, red_panda, panda, bear_cat, cat_bear, Ailurus_fulgens",
                            #   "caldron, cauldron", "dowitcher", "water_tower", "quill, quill_pen",
                            #   "balance_beam, beam", "unicycle, monocycle", "pencil_sharpener",
                            #   "puffer, pufferfish, blowfish, globefish"
                            #  ]
                            # ],
                            ["image_examples/mower.jpeg", 15, 100, "wide_resnet101_2", None, None
                             # ["washbasin, handbasin, washbowl, lavabo, wash-hand_basin",
                             #  "cucumber, cuke", "bolete", "oboe, hautboy, hautboi", "crane",
                             #  "wolf_spider, hunting_spider", "Norfolk_terrier", "nail", "sidewinder, horned_rattlesnake, Crotalus_cerastes",
                             #  "cannon", "beaker", "Shetland_sheepdog, Shetland_sheep_dog, Shetland",
                             #  "monitor", "restaurant, eating_house, eating_place, eatery", "electric_fan, blower"
                             # ]
                             ],
                            # ["image_examples/dog.jpeg", 20, 150, "xcit_small_12_p8_224", None, None
                            #  ["church, church_building", "axolotl, mud_puppy, Ambystoma_mexicanum",
                            #   "Scotch_terrier, Scottish_terrier, Scottie", "black-footed_ferret, ferret, Mustela_nigripes",
                            #   "lab_coat, laboratory_coat", "gyromitra", "grasshopper, hopper", "snail", "tabby, tabby_cat",
                            #   "bell_cote, bell_cot", "Indian_cobra, Naja_naja", "robin, American_robin, Turdus_migratorius",
                            #   "tiger_cat", "book_jacket, dust_cover, dust_jacket, dust_wrapper", "loudspeaker, speaker, speaker_unit, loudspeaker_system, speaker_system",
                            #   "washbasin, handbasin, washbowl, lavabo, wash-hand_basin", "electric_guitar", "armadillo", "ski_mask",
                            #   "convertible"
                            #  ]
                            # ],
                            ["image_examples/fish.jpeg", 10, 100, "pvt_v2_b0", None, None
                             # ["ground_beetle, carabid_beetle", "sunscreen, sunblock, sun_blocker", "brass, memorial_tablet, plaque", "Irish_terrier", "head_cabbage", "bathtub, bathing_tub, bath, tub",
                             #  "centipede", "squirrel_monkey, Saimiri_sciureus", "Chihuahua", "hourglass"
                             # ]
                             ]
                            ]
                  ).queue() as app:
    # Hide the Clear button, as clearing erases globals.
    for block in app.blocks:
        if isinstance(app.blocks[block], gr.Button):
            if app.blocks[block].value == "Clear":
                app.blocks[block].visible = False
    button.click(random_fill_classes, inputs=[desired_labels, top_k_slider], outputs=desired_labels)

if __name__ == "__main__":
    app.launch(server_port=9000)