import cv2
import torch
import numpy as np
import pycocotools.mask as mask_utils

# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class MaskList(object):
    """
    This class is unfinished and not meant for use yet.
    It is intended to hold the binary masks for all instances in an image,
    stored as a list of 2D arrays of shape (H, W).
    """

    def __init__(self, masks, size, mode):
        assert(isinstance(masks, list))
        assert(mode in ['mask', 'rle'])
        self.masks = masks
        self.size = size  # (image_width, image_height)
        self.mode = mode

    def transpose(self, method):
        assert (self.mode == "mask"), "RLE masks cannot be transposed. Please convert them to binary first."
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        # stack to (N, H, W) so all masks can be flipped in one call
        masks = np.array(self.masks)
        if masks.ndim == 2:
            masks = np.expand_dims(masks, axis=0)
        if method == FLIP_LEFT_RIGHT:
            masks = np.flip(masks, axis=2)
        elif method == FLIP_TOP_BOTTOM:
            masks = np.flip(masks, axis=1)
        flipped_masks = [mask.squeeze(0) for mask in np.split(masks, masks.shape[0])]
        return MaskList(flipped_masks, self.size, self.mode)

    def resize(self, size, *args, **kwargs):
        """
        Resize the binary masks.
        :param size: tuple, (image_width, image_height)
        :return: MaskList with resized masks
        """
        assert(self.mode == "mask"), "RLE masks cannot be resized. Please convert them to binary first."
        # stack to (H, W, N) so cv2.resize treats the N masks as channels
        cat_mask = np.array(self.masks).transpose(1, 2, 0)
        cat_mask = (cat_mask * 255).astype(np.uint8)
        resized_mask = cv2.resize(cat_mask, size)
        # cv2.resize drops the channel axis when there is only one mask
        if resized_mask.ndim == 2:
            resized_mask = np.expand_dims(resized_mask, axis=2)
        resized_mask = resized_mask.transpose(2, 0, 1)
        # binarize again; interpolated values below 255 become background
        resized_mask = resized_mask.astype(int) // 255
        # # visualize to check mask correctness:
        # # from matplotlib import pyplot as plt
        # # plt.imshow(resized_mask[0] * 255, cmap='gray'); plt.show()
        mask_list = [mask.squeeze(0) for mask in np.split(resized_mask, resized_mask.shape[0])]
        return MaskList(mask_list, size, "mask")

    def pad(self, size):
        """
        Zero-pad the binary masks to a new size. The new size must be at
        least as large as the original size in both dimensions.
        :param size: new image size, (image_width, image_height)
        :return: MaskList with padded masks
        """
        assert(size[0] >= self.size[0] and size[1] >= self.size[1]), "New size must be at least as large as the original size in both dimensions"
        cat_mask = np.array(self.masks)
        if cat_mask.ndim == 2:
            cat_mask = np.expand_dims(cat_mask, axis=0)
        # place the original masks in the top-left corner of the padded canvas
        padded_mask = np.zeros([len(self.masks), size[1], size[0]], dtype=cat_mask.dtype)
        padded_mask[:, :cat_mask.shape[1], :cat_mask.shape[2]] = cat_mask
        # # visualize to check mask correctness:
        # # from matplotlib import pyplot as plt
        # # plt.imshow(padded_mask[0] * 255, cmap='gray'); plt.show()
        mask_list = [mask.squeeze(0) for mask in np.split(padded_mask, padded_mask.shape[0])]
        return MaskList(mask_list, size, "mask")

    def convert(self, mode):
        """
        Convert masks between mode "mask" and mode "rle"
        :param mode: target mode, "mask" or "rle"
        :return: MaskList in the target mode
        """
        if mode == self.mode:
            return self
        elif mode == "rle" and self.mode == "mask":
            # use pycocotools to encode binary masks to RLE; encode expects
            # a Fortran-ordered uint8 array of shape (H, W, N)
            rle_mask_list = mask_utils.encode(
                np.asfortranarray(np.array(self.masks).transpose(1, 2, 0).astype(np.uint8))
            )
            return MaskList(rle_mask_list, self.size, "rle")
        elif mode == "mask" and self.mode == "rle":
            # use pycocotools to decode RLE to binary masks of shape (H, W, N)
            bimasks = mask_utils.decode(self.masks)
            mask_list = [mask.squeeze(0) for mask in np.split(bimasks.transpose(2, 0, 1), bimasks.shape[2])]
            return MaskList(mask_list, self.size, "mask")
        else:
            raise ValueError("Unsupported mode: {}".format(mode))

    def bbox(self, bbox_mode="xyxy"):
        """
        Generate bounding boxes from the binary masks. Not implemented yet.
        :param bbox_mode: box coordinate format, e.g. "xyxy"
        """
        pass

    def __len__(self):
        return len(self.masks)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_masks={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s
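
# A minimal usage sketch for MaskList, assuming two made-up 4x4 binary
# masks; the sizes and values below are purely illustrative:
#
#   masks = [np.zeros((4, 4), dtype=np.uint8) for _ in range(2)]
#   masks[0][1:3, 1:3] = 1
#   mask_list = MaskList(masks, size=(4, 4), mode="mask")
#   flipped = mask_list.transpose(FLIP_LEFT_RIGHT)  # mirror horizontally
#   resized = mask_list.resize((8, 8))              # scale up to an 8x8 image
#   rle = mask_list.convert("rle")                  # pycocotools RLE encoding
#   binary = rle.convert("mask")                    # and back to binary masks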


class Polygons(object):
    """
    This class holds the set of polygons that represents a single
    instance of an object mask.
    """

    def __init__(self, polygons, size, mode):
        if isinstance(polygons, list):
            polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons]
        elif isinstance(polygons, Polygons):
            polygons = polygons.polygons
        self.polygons = polygons
        self.size = size
        self.mode = mode

    def transpose(self, method):
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        flipped_polygons = []
        width, height = self.size
        if method == FLIP_LEFT_RIGHT:
            dim = width
            idx = 0
        elif method == FLIP_TOP_BOTTOM:
            dim = height
            idx = 1
        TO_REMOVE = 1
        for poly in self.polygons:
            p = poly.clone()
            # mirror every x (or y) coordinate; -1 accounts for pixel indexing
            p[idx::2] = dim - poly[idx::2] - TO_REMOVE
            flipped_polygons.append(p)
        return Polygons(flipped_polygons, size=self.size, mode=self.mode)

    def crop(self, box):
        w, h = box[2] - box[0], box[3] - box[1]
        # TODO check if necessary
        w = max(w, 1)
        h = max(h, 1)
        cropped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] = p[0::2] - box[0]  # .clamp(min=0, max=w)
            p[1::2] = p[1::2] - box[1]  # .clamp(min=0, max=h)
            cropped_polygons.append(p)
        return Polygons(cropped_polygons, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            scaled_polys = [p * ratio for p in self.polygons]
            return Polygons(scaled_polys, size, mode=self.mode)
        ratio_w, ratio_h = ratios
        scaled_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h
            scaled_polygons.append(p)
        return Polygons(scaled_polygons, size=size, mode=self.mode)

    def convert(self, mode):
        width, height = self.size
        if mode == "mask":
            # rasterize the polygons with pycocotools and merge them into a
            # single binary mask for this instance
            rles = mask_utils.frPyObjects(
                [p.detach().numpy() for p in self.polygons], height, width
            )
            rle = mask_utils.merge(rles)
            mask = mask_utils.decode(rle)
            mask = torch.from_numpy(mask)
            # TODO add squeeze?
            return mask

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_polygons={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s
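
# A minimal usage sketch for Polygons, assuming a single made-up square
# polygon (flattened as x0, y0, x1, y1, ...) in a 10x10 image:
#
#   square = [2.0, 2.0, 8.0, 2.0, 8.0, 8.0, 2.0, 8.0]
#   poly = Polygons([square], size=(10, 10), mode=None)
#   flipped = poly.transpose(FLIP_TOP_BOTTOM)  # mirror the y coordinates
#   scaled = poly.resize((20, 20))             # uniform 2x scaling
#   mask = poly.convert("mask")                # rasterize to a torch tensor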


class SegmentationMask(object):
    """
    This class stores the segmentations for all objects in the image.
    """

    def __init__(self, polygons, size, mode=None):
        """
        Arguments:
            polygons: a list of lists of lists of numbers. The first
                level of the list corresponds to individual instances,
                the second level to all the polygons that compose the
                object, and the third level to the polygon coordinates.
        """
        assert isinstance(polygons, list)
        self.polygons = [Polygons(p, size, mode) for p in polygons]
        self.size = size
        self.mode = mode

    def transpose(self, method):
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        flipped = []
        for polygon in self.polygons:
            flipped.append(polygon.transpose(method))
        return SegmentationMask(flipped, size=self.size, mode=self.mode)

    def crop(self, box):
        w, h = box[2] - box[0], box[3] - box[1]
        cropped = []
        for polygon in self.polygons:
            cropped.append(polygon.crop(box))
        return SegmentationMask(cropped, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        scaled = []
        for polygon in self.polygons:
            scaled.append(polygon.resize(size, *args, **kwargs))
        return SegmentationMask(scaled, size=size, mode=self.mode)

    def to(self, *args, **kwargs):
        # polygons live on the CPU; .to() is a no-op kept for tensor-like API parity
        return self

    def __getitem__(self, item):
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        else:
            # advanced indexing on a single dimension
            selected_polygons = []
            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
                # turn a boolean mask into a list of selected indices
                item = item.nonzero()
                item = item.squeeze(1) if item.numel() > 0 else item
                item = item.tolist()
            for i in item:
                selected_polygons.append(self.polygons[i])
        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)

    def __iter__(self):
        return iter(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
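

if __name__ == "__main__":
    # Quick smoke test (illustrative only): the polygon coordinates and image
    # size below are made up to exercise the SegmentationMask API end to end.
    polygons = [
        [[2.0, 2.0, 8.0, 2.0, 8.0, 8.0, 2.0, 8.0]],  # instance 0: a square
        [[1.0, 1.0, 5.0, 1.0, 3.0, 6.0]],            # instance 1: a triangle
    ]
    seg = SegmentationMask(polygons, size=(10, 10))
    print(seg)                        # SegmentationMask(num_instances=2, ...)
    flipped = seg.transpose(FLIP_LEFT_RIGHT)
    cropped = seg.crop([0, 0, 6, 6])  # keep the top-left 6x6 region
    resized = seg.resize((20, 20))    # polygon coordinates scale by 2x
    print(resized[0])                 # first instance, still a SegmentationMask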