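"""GraspAnything demo.

Loads a SAM (ViT-B) backbone whose forward/inference passes are wrapped with
grasp-detection heads, then serves a Gradio app: draw a box around an object
and the model overlays its segmentation mask plus the top-k predicted grasp
rectangles. Run with e.g. `python app.py` (filename assumed) with the
checkpoint referenced below placed next to this script.
"""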
import copy
import sys

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter

sys.path.append("./")
from models import sam_model_registry
from models.grasp_mods import add_inference_method, modify_forward
from models.utils.transforms import ResizeLongestSide
from structures.grasp_box import GraspCoder

# resize the longest image side to SAM's 1024-pixel input resolution
img_resize = ResizeLongestSide(1024)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "vit_b"
mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]
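# The (3, 1, 1) shape lets mean/std broadcast over a CHW image; the values
# appear to be the Caffe-style BGR ImageNet statistics.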
sam = sam_model_registry[model_type]()
sam.to(device=device)
sam.forward = modify_forward(sam)
sam.infer = add_inference_method(sam)
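# modify_forward / add_inference_method (see models/grasp_mods) presumably wrap
# SAM's stock mask-only forward pass with the grasp-detection heads; sam.infer
# is the prompt-to-prediction entry point used in predict() below.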
pretrained_model_path = "./epoch_9_step_535390.pth"
if pretrained_model_path != "":
    sd = torch.load(pretrained_model_path, map_location="cpu")
    # strip the "module." prefix that (Distributed)DataParallel adds to keys
    new_sd = {}
    for k, v in sd.items():
        if k.startswith("module."):
            k = k[len("module."):]
        new_sd[k] = v
    sam.load_state_dict(new_sd)
sam.eval()
def predict(prompts, topk):
    np_image = prompts["image"]
    points = prompts["points"]
    orig_size = np_image.shape[:2]
    # normalize image: HWC uint8 -> CHW, zero-mean / unit-std
    chw_image = np_image.transpose(2, 0, 1)
    image = (chw_image - mean) / std
    image = torch.tensor(image).float().to(device)
    image = image.unsqueeze(0)
    t_image = img_resize.apply_image_torch(image)
    t_orig_size = t_image.shape[-2:]
    # pad right/bottom to the fixed 1024x1024 input size
    t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
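    # F.pad orders its argument last-dim first: (left, right, top, bottom),
    # so this pads W up to 1024, then H up to 1024, anchored at the top-left
    # like SAM's own preprocessing.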
    # collect box prompts: ImagePrompter encodes a drawn box as a sextuple
    # (x1, y1, 2, x2, y2, 3), the type codes 2/3 marking its two corners
    valid_boxes = []
    for point in points:
        x1, y1, type1, x2, y2, type2 = point
        if type1 == 2 and type2 == 3:
            valid_boxes.append([x1, y1, x2, y2])
    if len(valid_boxes) == 0:
        # no box drawn: return the input image unchanged
        return np_image
    t_boxes = np.array(valid_boxes)
    t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
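    # batched_inputs follows SAM's batched-input convention: one dict per image,
    # holding the padded CHW tensor and its box prompts in the resized frame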
    with torch.no_grad():
        outputs = sam.infer(batched_inputs, multimask_output=False)
    # undo the normalization so the network input can be drawn on
    recovered_img = batched_inputs[0]["image"].cpu().numpy()
    recovered_img = recovered_img * std + mean
    recovered_img = recovered_img.transpose(1, 2, 0).clip(0, 255).astype(np.uint8)
    for i in range(len(outputs.pred_masks)):
        # predicted mask: sigmoid logits thresholded at 0.5, tiled to 3 channels
        pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)
        # predicted grasps: keep the topk candidates with the highest scores
        pred_logits = outputs.logits[i].detach().cpu().numpy()
        top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
        coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
        _ = coded_grasp.decode()
        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
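        # GraspCoder presumably maps the regressed grasp parameterization back
        # to pixel coordinates in the padded 1024x1024 frame; decode() fills
        # grasp_annos with drawable corner points.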
        # alpha-blend the mask onto the image in green
        mask_color = np.array([0, 255, 0])[None, None, :]
        recovered_img[pred_mask] = recovered_img[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5
        # draw grasps (cv2 drawing requires a contiguous array)
        recovered_img = np.ascontiguousarray(recovered_img)
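        # Each decoded grasp is an 8-vector of four (x, y) corners of a rotated
        # rectangle; thin and thick line pairs trace opposite edges (likely the
        # gripper plates vs. the opening). The array is RGB here, so
        # (255, 0, 0) draws red and (0, 0, 255) draws blue.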
        for grasp in decoded_grasp:
            grasp = grasp.astype(int)
            cv2.line(recovered_img, tuple(grasp[0:2]), tuple(grasp[2:4]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[4:6]), tuple(grasp[6:8]), (255, 0, 0), 1)
            cv2.line(recovered_img, tuple(grasp[2:4]), tuple(grasp[4:6]), (0, 0, 255), 2)
            cv2.line(recovered_img, tuple(grasp[6:8]), tuple(grasp[0:2]), (0, 0, 255), 2)
    # crop away the padding, then resize back to the original resolution
    recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
    recovered_img = cv2.resize(recovered_img, (orig_size[1], orig_size[0]))
    return recovered_img
if __name__ == "__main__":
    app = gr.Blocks(title="GraspAnything")
    with app:
        gr.Markdown("""
        # GraspAnything <br>
        Upload an image and draw a box around the object you want to grasp.
        Set Top K to the number of grasps to predict for each object.
        """)
        with gr.Column():
            prompter = ImagePrompter(show_label=False)
            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
        with gr.Column():
            image_output = gr.Image()
            btn = gr.Button("Generate!")
        btn.click(predict,
                  inputs=[prompter, top_k],
                  outputs=[image_output])
    app.launch()
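# Pass share=True to app.launch() to additionally expose the demo through a
# temporary public gradio.live URL when running on a remote machine.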