import os
import random
import re
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import bisect

import torch
import numpy as np
import transformers
from objectrelator.constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
                                     DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN,
                                     SEG_TOKEN_INDEX, DEFAULT_CLS_TOKEN, CLS_TOKEN_INDEX,
                                     DEFAULT_REGION_TOKEN, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX)
from torch.utils.data import Dataset
from objectrelator import conversation as conversation_lib
from objectrelator.model import *
from objectrelator.mm_utils import tokenizer_image_token
from objectrelator.mask_config.data_args import TrainingArguments
from PIL import Image
from objectrelator.mask_config.config import Config
from fvcore.common.config import CfgNode
from detectron2.structures import BoxMode
import warnings

warnings.filterwarnings('ignore')
|
|
|
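# This module defines the datasets and collators used for training/evaluation:
# COCO panoptic, instance, region-prompted ("interactive"), semantic and
# referring (RefCOCO) segmentation datasets, a multi-dataset wrapper, and a
# multimodal conversation dataset. Each dataset produces a detectron2-style
# data_dict plus LLaMA-2-formatted token ids in which the placeholders
# <image>, <seg>, <cls>, <region> and <refer> map to reserved token indices.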
local_rank = None
|
|
|
|
|
|
|
|
|
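# training_args is instantiated with its defaults at import time; in this file
# it is only consulted for the first_stage flag used by the training dataset.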
parser = transformers.HfArgumentParser(TrainingArguments)
|
|
|
training_args = TrainingArguments()
|
|
|
|
|
|
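# Build the Mask2Former-style mask-decoder configuration: the YAML file is read
# both with the python-side Config loader and with fvcore's CfgNode (resolving
# any _BASE_ files), and the two are merged into a single Config object.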
def get_mask_config(config='./objectrelator/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml'):
|
|
|
cfg_coco = Config.fromfile(config)
|
|
|
cfg_base = CfgNode.load_yaml_with_base(config, allow_unsafe=True)
|
|
|
cfg_base.update(cfg_coco.__dict__.items())
|
|
|
cfg = cfg_base
|
|
|
cfg = Config(cfg)
|
|
|
return cfg
|
|
|
class COCO_panoptic_dataset(Dataset):
|
|
|
def __init__(self, json_path, tokenizer, data_args, is_train=True):
|
|
|
        super(COCO_panoptic_dataset, self).__init__()
|
|
|
if is_train:
|
|
|
self.panoptic_gt_path = os.path.join(json_path,'panoptic_train2017')
|
|
|
self.panoptic_image_path = os.path.join(json_path,'train2017')
|
|
|
self.panoptic_json_path = os.path.join(json_path,'annotations/panoptic_train2017.json')
|
|
|
self.semantic_gt_path = os.path.join(json_path,'panoptic_semseg_train2017')
|
|
|
else:
|
|
|
self.panoptic_gt_path = os.path.join(json_path,'panoptic_val2017')
|
|
|
self.panoptic_image_path = os.path.join(json_path,'val2017')
|
|
|
self.panoptic_json_path = os.path.join(json_path,'annotations/panoptic_val2017.json')
|
|
|
self.semantic_gt_path = os.path.join(json_path,'panoptic_semseg_val2017')
|
|
|
|
|
|
with open(self.panoptic_json_path) as f:
|
|
|
data = json.load(f)
|
|
|
|
|
|
self.data = data['annotations']
|
|
|
self.tokenizer = tokenizer
|
|
|
self.data_args = data_args
|
|
|
self.mask_format = 'polygon'
|
|
|
coco_class_ids = [cat['id'] for cat in data['categories']]
|
|
|
coco_class_name = [cat['name'] for cat in data['categories']]
|
|
|
coco_is_thing = [cat['isthing'] for cat in data['categories']]
|
|
|
self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
|
|
|
self.coco_class_name = coco_class_name + ['background']
|
|
|
self.coco_is_thing = coco_is_thing
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
return len(self.data)
|
|
|
|
|
|
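    # Normalise each conversation turn for multimodal use: <image> is moved to
    # the front of the turn, <seg> is moved to the end, and for "mmtag"
    # conversation versions the image token is wrapped in <Image></Image> tags.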
def preprocess_multimodal(self, sources):
|
|
|
for source in sources:
|
|
|
for sentence in source:
|
|
|
if DEFAULT_IMAGE_TOKEN in sentence['value']:
|
|
|
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
|
|
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
|
|
|
sentence['value'] = sentence['value'].strip()
|
|
|
if "mmtag" in conversation_lib.default_conversation.version:
|
|
|
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
|
|
|
'<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
|
|
|
|
|
|
if DEFAULT_SEG_TOKEN in sentence['value']:
|
|
|
sentence['value'] = sentence['value'].replace(DEFAULT_SEG_TOKEN, '').strip()
|
|
|
sentence['value'] = sentence['value'] + '\n' + DEFAULT_SEG_TOKEN
|
|
|
sentence['value'] = sentence['value']
|
|
|
return sources
|
|
|
|
|
|
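    # Build LLaMA-2-formatted prompts from the conversations and tokenise them
    # with the special placeholder tokens. The labels are a copy of input_ids in
    # which every instruction segment (up to the "[/INST] " separator of each
    # round) is masked with IGNORE_INDEX, so loss is only computed on responses.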
def preprocess_llama2(self, sources, tokenizer):
|
|
|
conv = conversation_lib.default_conversation.copy()
|
|
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
|
|
|
|
|
|
|
|
conversations = []
|
|
|
for i, source in enumerate(sources):
|
|
|
if roles[source[0]["from"]] != conv.roles[0]:
|
|
|
|
|
|
source = source[1:]
|
|
|
|
|
|
conv.messages = []
|
|
|
for j, sentence in enumerate(source):
|
|
|
role = roles[sentence["from"]]
|
|
|
assert role == conv.roles[j % 2], f"{i}"
|
|
|
conv.append_message(role, sentence["value"])
|
|
|
conversations.append(conv.get_prompt())
|
|
|
|
|
|
|
|
|
|
|
|
input_ids = torch.stack(
|
|
|
[self.tokenizer_special_tokens(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
|
|
|
|
|
targets = input_ids.clone()
|
|
|
|
|
|
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
|
|
|
|
|
|
|
|
sep = "[/INST] "
|
|
|
for conversation, target in zip(conversations, targets):
|
|
|
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
|
|
|
|
|
rounds = conversation.split(conv.sep2)
|
|
|
cur_len = 1
|
|
|
target[:cur_len] = IGNORE_INDEX
|
|
|
for i, rou in enumerate(rounds):
|
|
|
if rou == "":
|
|
|
break
|
|
|
|
|
|
parts = rou.split(sep)
|
|
|
if len(parts) != 2:
|
|
|
break
|
|
|
parts[0] += sep
|
|
|
|
|
|
round_len = len(self.tokenizer_special_tokens(rou, tokenizer))
|
|
|
instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer)) - 2
|
|
|
|
|
|
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
|
|
|
|
|
cur_len += round_len
|
|
|
target[cur_len:] = IGNORE_INDEX
|
|
|
|
|
|
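            # Consistency check: if the reconstructed length differs from the
            # number of non-pad tokens, the whole target is masked out and the
            # sample contributes no loss.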
if cur_len < tokenizer.model_max_length:
|
|
|
if cur_len != total_len:
|
|
|
target[:] = IGNORE_INDEX
|
|
|
|
|
|
return dict(
|
|
|
input_ids=input_ids,
|
|
|
labels=targets,
|
|
|
)
|
|
|
|
|
|
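    # Split the prompt on the multimodal placeholders (<image>, <seg>, <cls>,
    # <region>) and map each placeholder to its reserved token index; ordinary
    # text chunks are encoded with the tokenizer without special tokens.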
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
|
|
|
region_token_index=REGION_TOKEN_INDEX, return_tensors=None):
|
|
|
input_ids = []
|
|
|
special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index}
|
|
|
prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>)', prompt)
|
|
|
|
|
|
for chunk in prompt_chunks:
|
|
|
if chunk in special_token_map:
|
|
|
input_ids.append(special_token_map[chunk])
|
|
|
else:
|
|
|
input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
else:
|
|
|
return input_ids
|
|
|
|
|
|
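    # Encode every candidate class name followed by the CLS_token delimiter.
    # Returns the flattened token ids plus cls_indices, which maps every token
    # back to the index of the class name it belongs to.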
def preprocess_class_name(self, CLS_token='[CAT]'):
|
|
|
tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in self.coco_class_name]
|
|
|
tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
|
|
|
tokenized]
|
|
|
class_name_id = [token for sublist in tokenized_class_names for token in sublist]
|
|
|
class_name_id = torch.tensor(class_name_id)
|
|
|
cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
|
|
|
cls_indices = torch.tensor(cls_indices)
|
|
|
|
|
|
return class_name_id, cls_indices
|
|
|
|
|
|
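    # Build a detectron2-style record (image path, panoptic/semantic PNG paths,
    # segments_info with contiguous category ids), run the panoptic processor,
    # and attach the panoptic-segmentation prompt with one <cls> placeholder per
    # candidate category.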
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_id = int(data["image_id"])
|
|
|
image_file = os.path.join(self.panoptic_image_path, os.path.splitext(data["file_name"])[0] + ".jpg")
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = image_file
|
|
|
data_dict['image_id'] = image_id
|
|
|
label_file = os.path.join(self.panoptic_gt_path, data["file_name"])
|
|
|
sem_label_file = os.path.join(self.semantic_gt_path, data["file_name"])
|
|
|
data_dict['pan_seg_file_name'] = label_file
|
|
|
data_dict['sem_seg_file_name'] = sem_label_file
|
|
|
segments_info = data["segments_info"]
|
|
|
for seg in segments_info:
|
|
|
seg['category_id'] = self.coco_id_to_cont_id[seg['category_id']]
|
|
|
data_dict['segments_info'] = segments_info
|
|
|
|
|
|
if isinstance(self.data_args.image_processor, dict):
|
|
|
processor = self.data_args.image_processor['panoptic']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)
|
|
|
instruction = 'Panoptic Segmentation: You need to segment all objects '
|
|
|
prefix_inst = 'This is an image <image>, Please do Panoptic Segmentation.'
|
|
|
|
|
|
num_class = len(self.coco_class_name)
|
|
|
category = '<cls>, ' * (num_class-1) + '<cls>.'
|
|
|
|
|
|
sources_value = f'\nThis is all the candidate categories: {category}\n'
|
|
|
|
|
|
sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
|
|
|
{'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]
|
|
|
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
|
|
|
class_name_ids, cls_indices = self.preprocess_class_name(CLS_token='[SEG]')
|
|
|
class_name_embedding_indices = torch.zeros_like(input_ids)
|
|
|
class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1
|
|
|
|
|
|
data_dict['input_ids'] = text_dict['input_ids'][0]
|
|
|
data_dict['labels'] = text_dict['labels'][0]
|
|
|
|
|
|
data_dict['class_name_ids'] = class_name_ids
|
|
|
data_dict['cls_indices'] = cls_indices
|
|
|
data_dict['class_name_embedding_indices'] = class_name_embedding_indices
|
|
|
return data_dict
|
|
|
|
|
|
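# Region-prompted ("interactive") segmentation dataset: the prompt contains one
# <region> placeholder per target instance and the model is asked to segment
# those regions. When training_args.first_stage is set, the training split is
# randomly subsampled to 1/20 of its original size.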
class COCO_interactive_dataset_train(COCO_panoptic_dataset):
|
|
|
def __init__(self, json_path, tokenizer, data_args):
|
|
|
if isinstance(json_path, list):
|
|
|
data = []
|
|
|
for path in json_path:
|
|
|
with open(path) as f:
|
|
|
cur_data = json.load(f)
|
|
|
data.extend(cur_data)
|
|
|
else:
|
|
|
with open(json_path) as f:
|
|
|
data = json.load(f)
|
|
|
|
|
|
self.data = data
|
|
|
|
|
|
|
|
|
if training_args.first_stage:
|
|
|
subset_size = len(self.data) // 20
|
|
|
self.data = random.sample(self.data, subset_size)
|
|
|
            print('Stage-1 training: using a 1/20 random subset of the data, length =', len(self.data))
|
|
|
|
|
|
|
|
|
self.tokenizer = tokenizer
|
|
|
self.data_args = data_args
|
|
|
coco_class_ids = [
|
|
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
|
|
|
18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
|
|
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
|
|
|
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
|
|
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
|
|
82, 84, 85, 86, 87, 88, 89, 90
|
|
|
]
|
|
|
coco_class_name = [
|
|
|
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
|
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
|
|
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
|
|
|
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
|
|
|
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
|
|
'sports ball', 'kite', 'baseball bat', 'baseball glove',
|
|
|
'skateboard', 'surfboard', 'tennis racket', 'bottle',
|
|
|
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
|
|
'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
|
|
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
|
|
'couch', 'potted plant', 'bed', 'dining table', 'toilet',
|
|
|
'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
|
|
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
|
|
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
|
|
]
|
|
|
|
|
|
self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
|
|
|
self.coco_class_name = coco_class_name + ['background']
|
|
|
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
|
|
|
region_token_index=REGION_TOKEN_INDEX, return_tensors=None):
|
|
|
input_ids = []
|
|
|
special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index}
|
|
|
prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>)', prompt)
|
|
|
|
|
|
for chunk in prompt_chunks:
|
|
|
if chunk in special_token_map:
|
|
|
input_ids.append(special_token_map[chunk])
|
|
|
else:
|
|
|
input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
else:
|
|
|
return input_ids
|
|
|
|
|
|
|
|
|
def preprocess_class_name(self, CLS_token='[CAT]'):
|
|
|
tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in self.coco_class_name]
|
|
|
tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
|
|
|
tokenized]
|
|
|
|
|
|
class_name_id = [token for sublist in tokenized_class_names for token in sublist]
|
|
|
class_name_id = torch.tensor(class_name_id)
|
|
|
cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
|
|
|
cls_indices = torch.tensor(cls_indices)
|
|
|
|
|
|
return class_name_id, cls_indices
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_file = data['image']
|
|
|
image_folder = self.data_args.image_folder
|
|
|
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
data_dict['height'] = data['image_info']['height']
|
|
|
data_dict['width'] = data['image_info']['width']
|
|
|
data_dict['image_id'] = data['new_img_id']
|
|
|
data_dict['annotations'] = data['anns']
|
|
|
for annotation in data_dict['annotations']:
|
|
|
annotation['bbox_mode'] = BoxMode.XYXY_ABS
|
|
|
|
|
|
if annotation['category_id'] in self.coco_id_to_cont_id:
|
|
|
annotation['category_id'] = self.coco_id_to_cont_id[annotation['category_id']]
|
|
|
elif annotation['category_id'] in self.coco_id_to_cont_id.values():
|
|
|
annotation['category_id'] = annotation['category_id']
|
|
|
else:
|
|
|
                raise ValueError(f"Unknown category id: {annotation['category_id']}")
|
|
|
annotation['image_id'] = data['new_img_id']
|
|
|
|
|
|
if isinstance(self.data_args.image_processor,dict):
|
|
|
processor = self.data_args.image_processor['instance']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
region_mask_type = getattr(self.data_args,'region_mask_type',None)
|
|
|
if region_mask_type is not None:
|
|
|
region_mask_type = region_mask_type.split('||')
|
|
|
data_dict = processor.preprocess(data_dict,region_mask_type=region_mask_type)
|
|
|
|
|
|
num_target = len(data_dict['instances'])
|
|
|
prefix_inst = 'This is an image <image>, Please segment by given regions'
|
|
|
regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
|
|
|
sources_value = f'\nThis is all regions: {regions_inst}\n'
|
|
|
|
|
|
sources = [
|
|
|
[{'from': 'human', 'value': prefix_inst + sources_value},
|
|
|
{'from': 'gpt', 'value': '\n[SEG]<seg>'}]]
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
data_dict['input_ids'] = input_ids
|
|
|
data_dict['labels'] = labels
|
|
|
data_dict['dataset_type'] = 'region_coco'
|
|
|
|
|
|
return data_dict
|
|
|
|
|
|
class COCO_interactive_dataset_eval(COCO_panoptic_dataset):
|
|
|
def __init__(self, data_list, tokenizer, data_args):
|
|
|
|
|
|
data = data_list
|
|
|
self.data = data
|
|
|
self.tokenizer = tokenizer
|
|
|
self.data_args = data_args
|
|
|
coco_class_ids = [
|
|
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
|
|
|
18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
|
|
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
|
|
|
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
|
|
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
|
|
82, 84, 85, 86, 87, 88, 89, 90
|
|
|
]
|
|
|
coco_class_name = [
|
|
|
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
|
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
|
|
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
|
|
|
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
|
|
|
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
|
|
'sports ball', 'kite', 'baseball bat', 'baseball glove',
|
|
|
'skateboard', 'surfboard', 'tennis racket', 'bottle',
|
|
|
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
|
|
'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
|
|
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
|
|
'couch', 'potted plant', 'bed', 'dining table', 'toilet',
|
|
|
'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
|
|
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
|
|
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
|
|
]
|
|
|
|
|
|
self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
|
|
|
self.coco_class_name = coco_class_name + ['background']
|
|
|
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
|
|
|
region_token_index=REGION_TOKEN_INDEX, return_tensors=None):
|
|
|
input_ids = []
|
|
|
special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index}
|
|
|
prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>)', prompt)
|
|
|
|
|
|
for chunk in prompt_chunks:
|
|
|
if chunk in special_token_map:
|
|
|
input_ids.append(special_token_map[chunk])
|
|
|
else:
|
|
|
input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
else:
|
|
|
return input_ids
|
|
|
|
|
|
def preprocess_class_name(self, CLS_token='[CAT]'):
|
|
|
tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in self.coco_class_name]
|
|
|
tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
|
|
|
tokenized]
|
|
|
|
|
|
class_name_id = [token for sublist in tokenized_class_names for token in sublist]
|
|
|
class_name_id = torch.tensor(class_name_id)
|
|
|
cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
|
|
|
cls_indices = torch.tensor(cls_indices)
|
|
|
|
|
|
return class_name_id, cls_indices
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_file = data['image']
|
|
|
image_folder = self.data_args.image_folder
|
|
|
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
data_dict['height'] = data['image_info']['height']
|
|
|
data_dict['width'] = data['image_info']['width']
|
|
|
data_dict['image_id'] = data['new_img_id']
|
|
|
data_dict['annotations'] = data['anns']
|
|
|
for annotation in data_dict['annotations']:
|
|
|
annotation['bbox_mode'] = BoxMode.XYXY_ABS
|
|
|
|
|
|
if annotation['category_id'] in self.coco_id_to_cont_id:
|
|
|
annotation['category_id'] = self.coco_id_to_cont_id[annotation['category_id']]
|
|
|
elif annotation['category_id'] in self.coco_id_to_cont_id.values():
|
|
|
annotation['category_id'] = annotation['category_id']
|
|
|
else:
|
|
|
                raise ValueError(f"Unknown category id: {annotation['category_id']}")
|
|
|
annotation['image_id'] = data['new_img_id']
|
|
|
|
|
|
if isinstance(self.data_args.image_processor,dict):
|
|
|
processor = self.data_args.image_processor['instance']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
region_mask_type = getattr(self.data_args,'region_mask_type',None)
|
|
|
if region_mask_type is not None:
|
|
|
region_mask_type = region_mask_type.split('||')
|
|
|
data_dict = processor.preprocess(data_dict,region_mask_type=region_mask_type)
|
|
|
|
|
|
num_target = len(data_dict['instances'])
|
|
|
prefix_inst = 'This is an image <image>, Please segment by given regions'
|
|
|
regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
|
|
|
sources_value = f'\nThis is all regions: {regions_inst}\n'
|
|
|
|
|
|
sources = [
|
|
|
[{'from': 'human', 'value': prefix_inst + sources_value},
|
|
|
{'from': 'gpt', 'value': '\n[SEG]<seg>'}]]
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
data_dict['input_ids'] = input_ids
|
|
|
data_dict['labels'] = labels
|
|
|
data_dict['dataset_type'] = 'region_coco'
|
|
|
|
|
|
return data_dict
|
|
|
|
|
|
class COCO_interactive_dataset(COCO_panoptic_dataset):
|
|
|
def __init__(self, json_path, tokenizer, data_args):
|
|
|
if isinstance(json_path, list):
|
|
|
data = []
|
|
|
for path in json_path:
|
|
|
with open(path) as f:
|
|
|
cur_data = json.load(f)
|
|
|
data.extend(cur_data)
|
|
|
else:
|
|
|
with open(json_path) as f:
|
|
|
data = json.load(f)
|
|
|
|
|
|
self.data = data
|
|
|
|
|
|
self.tokenizer = tokenizer
|
|
|
self.data_args = data_args
|
|
|
coco_class_ids = [
|
|
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
|
|
|
18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
|
|
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
|
|
|
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
|
|
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
|
|
82, 84, 85, 86, 87, 88, 89, 90
|
|
|
]
|
|
|
coco_class_name = [
|
|
|
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
|
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
|
|
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
|
|
|
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
|
|
|
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
|
|
'sports ball', 'kite', 'baseball bat', 'baseball glove',
|
|
|
'skateboard', 'surfboard', 'tennis racket', 'bottle',
|
|
|
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
|
|
'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
|
|
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
|
|
'couch', 'potted plant', 'bed', 'dining table', 'toilet',
|
|
|
'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
|
|
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
|
|
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
|
|
]
|
|
|
|
|
|
self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
|
|
|
self.coco_class_name = coco_class_name + ['background']
|
|
|
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
|
|
|
region_token_index=REGION_TOKEN_INDEX, return_tensors=None):
|
|
|
input_ids = []
|
|
|
special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index}
|
|
|
prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>)', prompt)
|
|
|
|
|
|
for chunk in prompt_chunks:
|
|
|
if chunk in special_token_map:
|
|
|
input_ids.append(special_token_map[chunk])
|
|
|
else:
|
|
|
input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
else:
|
|
|
return input_ids
|
|
|
|
|
|
|
|
|
def preprocess_class_name(self, CLS_token='[CAT]'):
|
|
|
tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in self.coco_class_name]
|
|
|
tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
|
|
|
tokenized]
|
|
|
|
|
|
class_name_id = [token for sublist in tokenized_class_names for token in sublist]
|
|
|
class_name_id = torch.tensor(class_name_id)
|
|
|
cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
|
|
|
cls_indices = torch.tensor(cls_indices)
|
|
|
|
|
|
return class_name_id, cls_indices
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_file = data['image']
|
|
|
image_folder = self.data_args.image_folder
|
|
|
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
data_dict['height'] = data['image_info']['height']
|
|
|
data_dict['width'] = data['image_info']['width']
|
|
|
data_dict['image_id'] = data['new_img_id']
|
|
|
data_dict['annotations'] = data['anns']
|
|
|
for annotation in data_dict['annotations']:
|
|
|
annotation['bbox_mode'] = BoxMode.XYXY_ABS
|
|
|
|
|
|
if annotation['category_id'] in self.coco_id_to_cont_id:
|
|
|
annotation['category_id'] = self.coco_id_to_cont_id[annotation['category_id']]
|
|
|
elif annotation['category_id'] in self.coco_id_to_cont_id.values():
|
|
|
annotation['category_id'] = annotation['category_id']
|
|
|
else:
|
|
|
                raise ValueError(f"Unknown category id: {annotation['category_id']}")
|
|
|
annotation['image_id'] = data['new_img_id']
|
|
|
|
|
|
if isinstance(self.data_args.image_processor,dict):
|
|
|
processor = self.data_args.image_processor['instance']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
region_mask_type = getattr(self.data_args,'region_mask_type',None)
|
|
|
if region_mask_type is not None:
|
|
|
region_mask_type = region_mask_type.split('||')
|
|
|
data_dict = processor.preprocess(data_dict,region_mask_type=region_mask_type)
|
|
|
|
|
|
num_target = len(data_dict['instances'])
|
|
|
prefix_inst = 'This is an image <image>, Please segment by given regions'
|
|
|
regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
|
|
|
sources_value = f'\nThis is all regions: {regions_inst}\n'
|
|
|
|
|
|
sources = [
|
|
|
[{'from': 'human', 'value': prefix_inst + sources_value},
|
|
|
{'from': 'gpt', 'value': '\n[SEG]<seg>'}]]
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
data_dict['input_ids'] = input_ids
|
|
|
data_dict['labels'] = labels
|
|
|
data_dict['dataset_type'] = 'region_coco'
|
|
|
|
|
|
return data_dict
|
|
|
|
|
|
|
|
|
class COCO_instance_dataset(COCO_interactive_dataset_train):
|
|
|
def __init__(self, json_path, tokenizer, data_args):
|
|
|
if isinstance(json_path, list):
|
|
|
data = []
|
|
|
for path in json_path:
|
|
|
with open(path) as f:
|
|
|
cur_data = json.load(f)
|
|
|
data.extend(cur_data)
|
|
|
else:
|
|
|
with open(json_path) as f:
|
|
|
data = json.load(f)
|
|
|
self.data = data
|
|
|
self.tokenizer = tokenizer
|
|
|
self.data_args = data_args
|
|
|
self.mask_format = 'polygon'
|
|
|
coco_class_ids = [
|
|
|
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17,
|
|
|
18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
|
|
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49,
|
|
|
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
|
|
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
|
|
82, 84, 85, 86, 87, 88, 89, 90
|
|
|
]
|
|
|
coco_class_name = [
|
|
|
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
|
|
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
|
|
'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
|
|
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
|
|
|
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
|
|
|
'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
|
|
|
'sports ball', 'kite', 'baseball bat', 'baseball glove',
|
|
|
'skateboard', 'surfboard', 'tennis racket', 'bottle',
|
|
|
'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
|
|
'banana', 'apple', 'sandwich', 'orange', 'broccoli',
|
|
|
'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
|
|
'couch', 'potted plant', 'bed', 'dining table', 'toilet',
|
|
|
'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
|
|
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book',
|
|
|
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
|
|
]
|
|
|
self.coco_id_to_cont_id = {coco_id: cont_id for cont_id, coco_id in enumerate(coco_class_ids)}
|
|
|
self.coco_class_name = coco_class_name + ['background']
|
|
|
|
|
|
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
|
|
|
region_token_index=REGION_TOKEN_INDEX, return_tensors=None):
|
|
|
input_ids = []
|
|
|
special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index}
|
|
|
prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>)', prompt)
|
|
|
|
|
|
for chunk in prompt_chunks:
|
|
|
if chunk in special_token_map:
|
|
|
input_ids.append(special_token_map[chunk])
|
|
|
else:
|
|
|
input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
else:
|
|
|
return input_ids
|
|
|
|
|
|
def preprocess_class_name(self, CLS_token='[CAT]'):
|
|
|
tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in self.coco_class_name]
|
|
|
tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
|
|
|
tokenized]
|
|
|
|
|
|
class_name_id = [token for sublist in tokenized_class_names for token in sublist]
|
|
|
class_name_id = torch.tensor(class_name_id)
|
|
|
cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
|
|
|
cls_indices = torch.tensor(cls_indices)
|
|
|
|
|
|
return class_name_id, cls_indices
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_file = data['image']
|
|
|
image_folder = self.data_args.image_folder
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
data_dict['height'] = data['image_info']['height']
|
|
|
data_dict['width'] = data['image_info']['width']
|
|
|
data_dict['image_id'] = data['new_img_id']
|
|
|
data_dict['annotations'] = data['anns']
|
|
|
for annotation in data_dict['annotations']:
|
|
|
annotation['bbox_mode'] = BoxMode.XYXY_ABS
|
|
|
if annotation['category_id'] in self.coco_id_to_cont_id:
|
|
|
annotation['category_id'] = self.coco_id_to_cont_id[annotation['category_id']]
|
|
|
elif annotation['category_id'] in self.coco_id_to_cont_id.values():
|
|
|
annotation['category_id'] = annotation['category_id']
|
|
|
else:
|
|
|
                raise ValueError(f"Unknown category id: {annotation['category_id']}")
|
|
|
annotation['image_id'] = data['new_img_id']
|
|
|
|
|
|
if isinstance(self.data_args.image_processor, dict):
|
|
|
processor = self.data_args.image_processor['instance']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)
|
|
|
data_dict['annotations'] = data['anns']
|
|
|
|
|
|
instruction = 'Panoptic Segmentation: You need to segment all objects '
|
|
|
prefix_inst = 'This is an image <image>, Please do Panoptic Segmentation.'
|
|
|
|
|
|
num_class = len(self.coco_class_name)
|
|
|
category = '<cls>, ' * (num_class - 1) + '<cls>.'
|
|
|
|
|
|
sources_value = f'\nThis is all the candidate categories: {category}\n'
|
|
|
|
|
|
sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
|
|
|
{'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]
|
|
|
|
|
|
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
|
|
|
class_name_ids, cls_indices = self.preprocess_class_name(CLS_token='[SEG]')
|
|
|
class_name_embedding_indices = torch.zeros_like(input_ids)
|
|
|
class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1
|
|
|
|
|
|
data_dict['input_ids'] = text_dict['input_ids'][0]
|
|
|
data_dict['labels'] = text_dict['labels'][0]
|
|
|
|
|
|
data_dict['class_name_ids'] = class_name_ids
|
|
|
data_dict['cls_indices'] = cls_indices
|
|
|
data_dict['class_name_embedding_indices'] = class_name_embedding_indices
|
|
|
return data_dict
|
|
|
|
|
|
|
|
|
|
|
|
class COCO_panoptic_dataset_random(COCO_panoptic_dataset):
|
|
|
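    # Shuffled variant of preprocess_class_name: the candidate class names are
    # tokenised in random order, and permute_idx stores the inverse permutation
    # (original class index -> position in the shuffled list).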
def preprocess_class_name(self, CLS_token='[CAT]'):
|
|
|
random_idx = list(range(len(self.coco_class_name)))
|
|
|
random.shuffle(random_idx)
|
|
|
random_class_name = [self.coco_class_name[i] for i in random_idx]
|
|
|
permute_idx = list(sorted(range(len(random_idx)), key=random_idx.__getitem__))
|
|
|
tokenized = [self.tokenizer.encode(class_name, add_special_tokens=False) for class_name in random_class_name]
|
|
|
tokenized_class_names = [tokens + [self.tokenizer.encode(CLS_token, add_special_tokens=False)[0]] for tokens in
|
|
|
tokenized]
|
|
|
class_name_id = [token for sublist in tokenized_class_names for token in sublist]
|
|
|
class_name_id = torch.tensor(class_name_id)
|
|
|
cls_indices = [idx for idx, sublist in enumerate(tokenized_class_names) for _ in sublist]
|
|
|
cls_indices = torch.tensor(cls_indices)
|
|
|
|
|
|
permute_idx = torch.tensor(permute_idx)
|
|
|
|
|
|
|
|
|
return class_name_id, cls_indices, permute_idx
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_id = int(data["image_id"])
|
|
|
image_file = os.path.join(self.panoptic_image_path, os.path.splitext(data["file_name"])[0] + ".jpg")
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = image_file
|
|
|
data_dict['image_id'] = image_id
|
|
|
label_file = os.path.join(self.panoptic_gt_path, data["file_name"])
|
|
|
sem_label_file = os.path.join(self.semantic_gt_path, data["file_name"])
|
|
|
data_dict['pan_seg_file_name'] = label_file
|
|
|
data_dict['sem_seg_file_name'] = sem_label_file
|
|
|
segments_info = data["segments_info"]
|
|
|
for seg in segments_info:
|
|
|
if seg['category_id'] in self.coco_id_to_cont_id:
|
|
|
seg['category_id'] = self.coco_id_to_cont_id[seg['category_id']]
|
|
|
elif seg['category_id'] in self.coco_id_to_cont_id.values():
|
|
|
seg['category_id'] = seg['category_id']
|
|
|
else:
|
|
|
                raise ValueError(f"Unknown category id: {seg['category_id']}")
|
|
|
data_dict['segments_info'] = segments_info
|
|
|
|
|
|
|
|
|
|
|
|
processor = self.data_args.image_processor['panoptic']
|
|
|
data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)
|
|
|
|
|
|
instruction = 'Panoptic Segmentation: You need to segment all objects '
|
|
|
|
|
|
num_class = len(self.coco_class_name)
|
|
|
category = '<cls>, ' * (num_class-1) + '<cls>.'
|
|
|
|
|
|
sources_value = f'This is all the candidate categories: {category}\n<image>\n'
|
|
|
|
|
|
sources = [[{'from': 'human', 'value': sources_value + instruction},
|
|
|
{'from': 'gpt', 'value': '\n[SEG]<seg>'}]]
|
|
|
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
|
|
|
class_name_ids, cls_indices, random_idx = self.preprocess_class_name()
|
|
|
data_dict['random_idx'] = random_idx
|
|
|
class_name_embedding_indices = torch.zeros_like(input_ids)
|
|
|
class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1
|
|
|
|
|
|
data_dict['input_ids'] = text_dict['input_ids'][0]
|
|
|
data_dict['labels'] = text_dict['labels'][0]
|
|
|
data_dict['dataset_type'] = 'panoptic_coco'
|
|
|
|
|
|
data_dict['class_name_ids'] = class_name_ids
|
|
|
data_dict['cls_indices'] = cls_indices
|
|
|
data_dict['class_name_embedding_indices'] = class_name_embedding_indices
|
|
|
return data_dict
|
|
|
|
|
|
class COCO_semantic_dataset(COCO_panoptic_dataset):
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_id = int(data["image_id"])
|
|
|
image_file = os.path.join(self.panoptic_image_path, os.path.splitext(data["file_name"])[0] + ".jpg")
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = image_file
|
|
|
data_dict['image_id'] = image_id
|
|
|
label_file = os.path.join(self.panoptic_gt_path, data["file_name"])
|
|
|
sem_label_file = os.path.join(self.semantic_gt_path, data["file_name"])
|
|
|
data_dict['pan_seg_file_name'] = sem_label_file
|
|
|
data_dict['sem_seg_file_name'] = sem_label_file
|
|
|
segments_info = data["segments_info"]
|
|
|
for seg in segments_info:
|
|
|
seg['category_id'] = self.coco_id_to_cont_id[seg['category_id']]
|
|
|
data_dict['segments_info'] = segments_info
|
|
|
|
|
|
if isinstance(self.data_args.image_processor, dict):
|
|
|
processor = self.data_args.image_processor['panoptic']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)
|
|
|
|
|
|
instruction = 'Panoptic Segmentation: You need to segment all objects '
|
|
|
prefix_inst = 'This is an image <image>, Please do Semantic Segmentation.'
|
|
|
|
|
|
num_class = len(self.coco_class_name)
|
|
|
category = '<cls>, ' * (num_class-1) + '<cls>.'
|
|
|
|
|
|
sources_value = f'\nThis is all the candidate categories: {category}\n'
|
|
|
|
|
|
sources = [[{'from': 'human', 'value': prefix_inst + sources_value},
|
|
|
{'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]
|
|
|
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
|
|
|
class_name_ids, cls_indices = self.preprocess_class_name(CLS_token='[SEG]')
|
|
|
class_name_embedding_indices = torch.zeros_like(input_ids)
|
|
|
class_name_embedding_indices[input_ids == CLS_TOKEN_INDEX] = 1
|
|
|
|
|
|
data_dict['input_ids'] = text_dict['input_ids'][0]
|
|
|
data_dict['labels'] = text_dict['labels'][0]
|
|
|
|
|
|
data_dict['class_name_ids'] = class_name_ids
|
|
|
data_dict['cls_indices'] = cls_indices
|
|
|
data_dict['class_name_embedding_indices'] = class_name_embedding_indices
|
|
|
return data_dict
|
|
|
|
|
|
|
|
|
|
|
|
class RefCOCO_dataset(COCO_instance_dataset):
|
|
|
|
|
|
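    # Encode the referring expression and append the REFER_token delimiter; the
    # ids are returned as token_refer_id and consumed together with
    # refer_embedding_indices, which mark the <refer> placeholder positions.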
def preprocess_referring_instruction(self,instruction, REFER_token='[SEG]'):
|
|
|
tokenized = self.tokenizer.encode(instruction, add_special_tokens=False)
|
|
|
tokenized = tokenized + [self.tokenizer.encode(REFER_token, add_special_tokens=False)[0]]
|
|
|
|
|
|
token_refer_id = torch.tensor(tokenized)
|
|
|
|
|
|
return token_refer_id
|
|
|
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
|
|
|
region_token_index=REGION_TOKEN_INDEX,refer_token_index=REFER_TOKEN_INDEX, return_tensors=None):
|
|
|
input_ids = []
|
|
|
special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index, '<refer>':refer_token_index}
|
|
|
prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>|<refer>)', prompt)
|
|
|
|
|
|
for chunk in prompt_chunks:
|
|
|
if chunk in special_token_map:
|
|
|
input_ids.append(special_token_map[chunk])
|
|
|
else:
|
|
|
input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
else:
|
|
|
return input_ids
|
|
|
def __getitem__(self, idx):
|
|
|
data = self.data[idx]
|
|
|
image_file = data['image_info']['file_name']
|
|
|
image_folder = self.data_args.refcoco_image_folder
|
|
|
|
|
|
data_dict = {}
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
data_dict['height'] = data['image_info']['height']
|
|
|
data_dict['width'] = data['image_info']['width']
|
|
|
data_dict['image_id'] = data['new_img_id']
|
|
|
data_dict['annotations'] = data['anns']
|
|
|
for annotation in data_dict['annotations']:
|
|
|
annotation['bbox_mode'] = BoxMode.XYXY_ABS
|
|
|
|
|
|
if annotation['category_id'] in self.coco_id_to_cont_id:
|
|
|
annotation['category_id'] = self.coco_id_to_cont_id[annotation['category_id']]
|
|
|
elif annotation['category_id'] in self.coco_id_to_cont_id.values():
|
|
|
annotation['category_id'] = annotation['category_id']
|
|
|
else:
|
|
|
                raise ValueError(f"Unknown category id: {annotation['category_id']}")
|
|
|
annotation['image_id'] = data['new_img_id']
|
|
|
|
|
|
if isinstance(self.data_args.image_processor,dict):
|
|
|
processor = self.data_args.image_processor['instance']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
data_dict = processor.preprocess(data_dict, mask_format=self.mask_format)
|
|
|
|
|
|
sentences = data['instruction']
|
|
|
|
|
|
prefix_inst = 'This is an image <image>, Please doing Referring Segmentation according to the following instruction:'
|
|
|
instruction = ''
|
|
|
for sent in sentences:
|
|
|
instruction += ' {}.'.format(sent['sent'])
|
|
|
sources = [[{'from': 'human', 'value': prefix_inst + '\n<refer>'},
|
|
|
{'from': 'gpt', 'value': '\nSure, the segmentation result is <seg>'}]]
|
|
|
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
input_ids = text_dict['input_ids'][0]
|
|
|
labels = text_dict['labels'][0]
|
|
|
|
|
|
token_refer_id = self.preprocess_referring_instruction(instruction)
|
|
|
refer_embedding_indices = torch.zeros_like(input_ids)
|
|
|
refer_embedding_indices[input_ids == REFER_TOKEN_INDEX] = 1
|
|
|
|
|
|
data_dict['input_ids'] = text_dict['input_ids'][0]
|
|
|
data_dict['labels'] = text_dict['labels'][0]
|
|
|
data_dict['dataset_type'] = 'referring_coco'
|
|
|
|
|
|
data_dict['token_refer_id'] = token_refer_id
|
|
|
data_dict['refer_embedding_indices'] = refer_embedding_indices
|
|
|
return data_dict
|
|
|
|
|
|
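# Module-level variant of preprocess_multimodal used for the conversation data:
# besides repositioning <image>, it wraps the image token with
# DEFAULT_IM_START_TOKEN / DEFAULT_IM_END_TOKEN when mm_use_im_start_end is set.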
def preprocess_multimodal(
|
|
|
sources,
|
|
|
data_args
|
|
|
):
|
|
|
is_multimodal = data_args.is_multimodal
|
|
|
if not is_multimodal:
|
|
|
return sources
|
|
|
|
|
|
for source in sources:
|
|
|
for sentence in source:
|
|
|
if DEFAULT_IMAGE_TOKEN in sentence['value']:
|
|
|
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
|
|
|
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
|
|
|
sentence['value'] = sentence['value'].strip()
|
|
|
if "mmtag" in conversation_lib.default_conversation.version:
|
|
|
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
|
|
|
'<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
|
|
|
replace_token = DEFAULT_IMAGE_TOKEN
|
|
|
if data_args.mm_use_im_start_end:
|
|
|
replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
|
|
|
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
|
|
|
|
|
return sources
|
|
|
|
|
|
class UnifyDatasetSingleDatasetForBatch(Dataset):
|
|
|
"""
|
|
|
Dataset to concatenate multiple datasets.
|
|
|
Purpose: useful to assemble different existing datasets, possibly
|
|
|
large-scale datasets as the concatenation operation is done in an
|
|
|
on-the-fly manner.
|
|
|
Arguments:
|
|
|
datasets (sequence): List of datasets to be concatenated
|
|
|
"""
|
|
|
|
|
|
@staticmethod
|
|
|
def cumsum(sequence):
|
|
|
r, s = [], 0
|
|
|
for e in sequence:
|
|
|
l = len(e)
|
|
|
r.append(l + s)
|
|
|
s += l
|
|
|
return r
|
|
|
|
|
|
|
|
|
def __init__(self,datasets,dataset_ratio,bs,fix_dataset_len=0):
|
|
|
super(UnifyDatasetSingleDatasetForBatch, self).__init__()
|
|
|
assert len(datasets) > 0, 'datasets should not be an empty iterable'
|
|
|
self.fix_dataset_len = fix_dataset_len
|
|
|
|
|
|
self.cnt = 0
|
|
|
self.bs = bs
|
|
|
|
|
|
self.datasets = list(datasets)
|
|
|
self.datasets_index_list = list(range(len(datasets)))
|
|
|
self.dataset_ratio = dataset_ratio
|
|
|
self.cur_dataset_index=0
|
|
|
self.dataset_length = [len(data) for data in self.datasets]
|
|
|
self.cumulative_sizes = self.cumsum(self.datasets)
|
|
|
self.coco_id_to_cont_id = {}
|
|
|
self.coco_class_name = {}
|
|
|
for _dataset in self.datasets:
|
|
|
dataset_coco_id_to_cont_id = _dataset.coco_id_to_cont_id if hasattr(_dataset,'coco_id_to_cont_id') else []
|
|
|
if len(dataset_coco_id_to_cont_id) > len(self.coco_id_to_cont_id):
|
|
|
self.coco_id_to_cont_id = dataset_coco_id_to_cont_id
|
|
|
for _dataset in self.datasets:
|
|
|
_dataset.coco_id_to_cont_id = self.coco_id_to_cont_id
|
|
|
for _dataset in self.datasets:
|
|
|
dataset_coco_class_name = _dataset.coco_class_name if hasattr(_dataset,'coco_class_name') else []
|
|
|
if len(dataset_coco_class_name) > len(self.coco_class_name):
|
|
|
self.coco_class_name = dataset_coco_class_name
|
|
|
for _dataset in self.datasets:
|
|
|
_dataset.coco_class_name = self.coco_class_name
|
|
|
|
|
|
|
|
|
|
|
|
def update_dataset_index(self):
|
|
|
tempt = self.cur_dataset_index
|
|
|
tempt += 1
|
|
|
tempt = tempt % len(self.datasets)
|
|
|
self.cur_dataset_index = tempt
|
|
|
|
|
|
def __len__(self):
|
|
|
if self.fix_dataset_len == 0:
|
|
|
return self.cumulative_sizes[-1]
|
|
|
else:
|
|
|
return self.fix_dataset_len
|
|
|
|
|
|
|
|
|
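    # Serve `bs` consecutive samples from the current dataset (indexing modulo
    # its length), then advance to the next dataset, so that each batch is
    # intended to come from a single dataset.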
def __getitem__(self, idx):
|
|
|
cur_dataset_len = self.dataset_length[self.cur_dataset_index]
|
|
|
data_idx = idx % cur_dataset_len
|
|
|
output_data = self.datasets[self.cur_dataset_index][data_idx]
|
|
|
self.cnt += 1
|
|
|
if self.cnt == self.bs:
|
|
|
self.cnt = 0
|
|
|
self.update_dataset_index()
|
|
|
return output_data
|
|
|
|
|
|
|
|
|
|
|
|
class MM_Conv_Dataset(Dataset):
|
|
|
def __init__(self, data_path,
|
|
|
tokenizer,
|
|
|
data_args):
|
|
|
super(MM_Conv_Dataset, self).__init__()
|
|
|
        with open(data_path, "r") as f:
            list_data_dict = json.load(f)
|
|
|
|
|
|
print("Formatting inputs...Skip in lazy mode")
|
|
|
self.tokenizer = tokenizer
|
|
|
self.list_data_dict = list_data_dict
|
|
|
self.data_args = data_args
|
|
|
|
|
|
def __len__(self):
|
|
|
return len(self.list_data_dict)
|
|
|
|
|
|
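    # Same instruction-masking scheme as the dataset classes above, with
    # separate token-length bookkeeping for the 'phi' conversation version,
    # which starts counting at 0 and uses different per-round offsets.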
def preprocess_llama2(self, sources, tokenizer):
|
|
|
conv = conversation_lib.default_conversation.copy()
|
|
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
|
|
|
|
|
|
|
|
conversations = []
|
|
|
for i, source in enumerate(sources):
|
|
|
if roles[source[0]["from"]] != conv.roles[0]:
|
|
|
|
|
|
source = source[1:]
|
|
|
|
|
|
conv.messages = []
|
|
|
for j, sentence in enumerate(source):
|
|
|
role = roles[sentence["from"]]
|
|
|
assert role == conv.roles[j % 2], f"{i}"
|
|
|
conv.append_message(role, sentence["value"])
|
|
|
conversations.append(conv.get_prompt())
|
|
|
|
|
|
|
|
|
|
|
|
input_ids = torch.stack(
|
|
|
[self.tokenizer_special_tokens(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
|
|
|
|
|
|
targets = input_ids.clone()
|
|
|
|
|
|
assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
|
|
|
|
|
|
|
|
|
sep = "[/INST] "
|
|
|
idx = 0
|
|
|
for conversation, target in zip(conversations, targets):
|
|
|
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
|
|
|
|
|
rounds = conversation.split(conv.sep2)
|
|
|
if conv.version == 'phi':
|
|
|
cur_len = 0
|
|
|
target[:cur_len] = IGNORE_INDEX
|
|
|
idx = 0
|
|
|
for i, rou in enumerate(rounds):
|
|
|
if rou == "":
|
|
|
continue
|
|
|
|
|
|
parts = rou.split(sep)
|
|
|
if len(parts) != 2:
|
|
|
break
|
|
|
parts[0] += sep
|
|
|
if idx > 0:
|
|
|
round_len = len(self.tokenizer_special_tokens(rou, tokenizer)) + 2
|
|
|
else:
|
|
|
round_len = len(self.tokenizer_special_tokens(rou, tokenizer)) + 1
|
|
|
if idx > 0:
|
|
|
instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer))
|
|
|
else:
|
|
|
instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer)) - 1
|
|
|
|
|
|
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
|
|
|
|
|
cur_len += round_len
|
|
|
idx += 1
|
|
|
target[cur_len:] = IGNORE_INDEX
|
|
|
else:
|
|
|
cur_len = 1
|
|
|
target[:cur_len] = IGNORE_INDEX
|
|
|
for i, rou in enumerate(rounds):
|
|
|
if rou == "":
|
|
|
continue
|
|
|
|
|
|
parts = rou.split(sep)
|
|
|
if len(parts) != 2:
|
|
|
break
|
|
|
parts[0] += sep
|
|
|
round_len = len(self.tokenizer_special_tokens(rou, tokenizer))
|
|
|
instruction_len = len(self.tokenizer_special_tokens(parts[0], tokenizer)) - 2
|
|
|
|
|
|
target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
|
|
|
|
|
|
cur_len += round_len
|
|
|
idx += 1
|
|
|
target[cur_len:] = IGNORE_INDEX
|
|
|
|
|
|
if cur_len < tokenizer.model_max_length:
|
|
|
if cur_len != total_len:
|
|
|
target[:] = IGNORE_INDEX
|
|
|
return dict(
|
|
|
input_ids=input_ids,
|
|
|
labels=targets,
|
|
|
)
|
|
|
|
|
|
|
|
|
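    # Simplified tokenizer for the conversation data: only <image> and <seg> are
    # treated as placeholders, and plain-text chunks are encoded with the
    # tokenizer's default settings (which typically prepend a BOS token).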
def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
|
|
|
seg_token_index=SEG_TOKEN_INDEX, return_tensors=None):
|
|
|
prompt_chunks = []
|
|
|
special_tokens = []
|
|
|
image_splits = prompt.split('<image>')
|
|
|
|
|
|
for i, chunk in enumerate(image_splits):
|
|
|
if i != 0:
|
|
|
special_tokens.append('<image>')
|
|
|
seg_splits = chunk.split('<seg>')
|
|
|
prompt_chunks.extend(seg_splits)
|
|
|
special_tokens.extend(['<seg>'] * (len(seg_splits)-1))
|
|
|
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt_chunks]
|
|
|
special_indexes = [image_token_index if token == '<image>' else seg_token_index for token in special_tokens]
|
|
|
input_ids = []
|
|
|
for i, chunk in enumerate(prompt_chunks):
|
|
|
input_ids.extend(chunk)
|
|
|
if i != len(prompt_chunks) -1:
|
|
|
input_ids.extend([special_indexes[i]])
|
|
|
if return_tensors is not None:
|
|
|
if return_tensors == 'pt':
|
|
|
return torch.tensor(input_ids, dtype=torch.long).squeeze()
|
|
|
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
|
|
return input_ids
|
|
|
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
|
|
sources = self.list_data_dict[i]
|
|
|
|
|
|
|
|
|
sources = [sources]
|
|
|
        assert len(sources) == 1, "expected a single conversation per sample"
|
|
|
data_dict = {}
|
|
|
if 'image' in sources[0]:
|
|
|
image_file = self.list_data_dict[i]['image']
|
|
|
image_folder = self.data_args.mmconv_path
|
|
|
if isinstance(self.data_args.image_processor, dict):
|
|
|
processor = self.data_args.image_processor['instance']
|
|
|
else:
|
|
|
processor = self.data_args.image_processor
|
|
|
if 'coco' in image_file:
|
|
|
image_folder = self.data_args.image_folder
|
|
|
image_file = os.path.basename(image_file)
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
else:
|
|
|
data_dict['file_name'] = os.path.join(image_folder, image_file)
|
|
|
data_dict = processor.preprocess(data_dict)
|
|
|
|
|
|
sources = preprocess_multimodal(
|
|
|
copy.deepcopy([e["conversations"] for e in sources]),
|
|
|
self.data_args)
|
|
|
else:
|
|
|
sources = copy.deepcopy([e["conversations"] for e in sources])
|
|
|
text_dict = self.preprocess_llama2(sources, self.tokenizer)
|
|
|
data_dict['input_ids'] = text_dict['input_ids'][0]
|
|
|
data_dict['labels'] = text_dict['labels'][0]
|
|
|
data_dict['dataset_type'] = 'mm_conv'
|
|
|
if 'image' not in data_dict:
|
|
|
|
|
|
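            # Text-only sample: insert a dummy all-zero image so downstream
            # batching still finds an 'image' field (1024 presumably matches the
            # segmentation image processor's crop size).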
crop_size = 1024
|
|
|
data_dict['image'] = torch.zeros(3, crop_size, crop_size)
|
|
|
return data_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class DataCollatorForCOCODatasetV2(object):
|
|
|
"""Collate examples for supervised fine-tuning."""
|
|
|
|
|
|
tokenizer: transformers.PreTrainedTokenizer
|
|
|
|
|
|
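    # Pad input_ids/labels to the longest sequence in the batch, derive the
    # attention mask from non-pad positions, stack image (and 'vp_image')
    # tensors when their shapes match, and collect the remaining per-sample
    # fields into batch['seg_info'] for the segmentation branch.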
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
|
|
input_ids, labels = tuple([instance[key] for instance in instances]
|
|
|
for key in ("input_ids", "labels"))
|
|
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
|
|
input_ids,
|
|
|
batch_first=True,
|
|
|
padding_value=self.tokenizer.pad_token_id)
|
|
|
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
|
|
batch_first=True,
|
|
|
padding_value=IGNORE_INDEX)
|
|
|
input_ids = input_ids[:, :self.tokenizer.model_max_length]
|
|
|
labels = labels[:, :self.tokenizer.model_max_length]
|
|
|
batch = dict(
|
|
|
input_ids=input_ids,
|
|
|
labels=labels,
|
|
|
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
|
|
)
|
|
|
if 'image' in instances[0]:
|
|
|
images = [instance['image'] for instance in instances]
|
|
|
if all(x is not None and x.shape == images[0].shape for x in images):
|
|
|
batch['images'] = torch.stack(images)
|
|
|
else:
|
|
|
batch['images'] = images
|
|
|
|
|
|
if 'vp_image' in instances[0]:
|
|
|
vp_images = [instance['vp_image'] for instance in instances]
|
|
|
if all(x is not None and x.shape == vp_images[0].shape for x in vp_images):
|
|
|
batch['vp_images'] = torch.stack(vp_images)
|
|
|
else:
|
|
|
batch['vp_images'] = vp_images
|
|
|
for instance in instances:
|
|
|
for key in ['input_ids', 'labels', 'image']:
|
|
|
del instance[key]
|
|
|
batch['seg_info'] = [instance for instance in instances]
|
|
|
|
|
|
if 'dataset_type' in instances[0]:
|
|
|
batch['dataset_type'] = [instance['dataset_type'] for instance in instances]
|
|
|
|
|
|
if 'class_name_ids' in instances[0]:
|
|
|
class_name_ids = [instance['class_name_ids'] for instance in instances]
|
|
|
if any(x.shape != class_name_ids[0].shape for x in class_name_ids):
|
|
|
batch['class_name_ids'] = torch.nn.utils.rnn.pad_sequence(
|
|
|
class_name_ids,
|
|
|
batch_first=True,
|
|
|
padding_value=-1,
|
|
|
)
|
|
|
else:
|
|
|
batch['class_name_ids'] = torch.stack(class_name_ids, dim=0)
|
|
|
if 'token_refer_id' in instances[0]:
|
|
|
token_refer_id = [instance['token_refer_id'] for instance in instances]
|
|
|
batch['token_refer_id'] = token_refer_id
|
|
|
if 'cls_indices' in instances[0]:
|
|
|
cls_indices = [instance['cls_indices'] for instance in instances]
|
|
|
if any(x.shape != cls_indices[0].shape for x in cls_indices):
|
|
|
batch['cls_indices'] = torch.nn.utils.rnn.pad_sequence(
|
|
|
cls_indices,
|
|
|
batch_first=True,
|
|
|
padding_value=-1,
|
|
|
)
|
|
|
else:
|
|
|
batch['cls_indices'] = torch.stack(cls_indices, dim=0)
|
|
|
if 'random_idx' in instances[0]:
|
|
|
random_idxs = [instance['random_idx'] for instance in instances]
|
|
|
batch['random_idx'] = torch.stack(random_idxs, dim=0)
|
|
|
if 'class_name_embedding_indices' in instances[0]:
|
|
|
class_name_embedding_indices = [instance['class_name_embedding_indices'] for instance in instances]
|
|
|
class_name_embedding_indices = torch.nn.utils.rnn.pad_sequence(
|
|
|
class_name_embedding_indices,
|
|
|
batch_first=True,
|
|
|
padding_value=0)
|
|
|
batch['class_name_embedding_indices'] = class_name_embedding_indices
|
|
|
if 'refer_embedding_indices' in instances[0]:
|
|
|
refer_embedding_indices = [instance['refer_embedding_indices'] for instance in instances]
|
|
|
refer_embedding_indices = torch.nn.utils.rnn.pad_sequence(
|
|
|
refer_embedding_indices,
|
|
|
batch_first=True,
|
|
|
padding_value=0)
|
|
|
batch['refer_embedding_indices'] = refer_embedding_indices
|
|
|
|
|
|
return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
class DataCollatorForCOCODatasetV2_old(object):
|
|
|
"""Collate examples for supervised fine-tuning."""
|
|
|
|
|
|
tokenizer: transformers.PreTrainedTokenizer
|
|
|
|
|
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
|
|
input_ids, labels = tuple([instance[key] for instance in instances]
|
|
|
for key in ("input_ids", "labels"))
|
|
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
|
|
input_ids,
|
|
|
batch_first=True,
|
|
|
padding_value=self.tokenizer.pad_token_id)
|
|
|
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
|
|
batch_first=True,
|
|
|
padding_value=IGNORE_INDEX)
|
|
|
input_ids = input_ids[:, :self.tokenizer.model_max_length]
|
|
|
labels = labels[:, :self.tokenizer.model_max_length]
|
|
|
batch = dict(
|
|
|
input_ids=input_ids,
|
|
|
labels=labels,
|
|
|
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
|
|
)
|
|
|
if 'image' in instances[0]:
|
|
|
images = [instance['image'] for instance in instances]
|
|
|
if all(x is not None and x.shape == images[0].shape for x in images):
|
|
|
batch['images'] = torch.stack(images)
|
|
|
else:
|
|
|
batch['images'] = images
|
|
|
for instance in instances:
|
|
|
for key in ['input_ids', 'labels', 'image']:
|
|
|
del instance[key]
|
|
|
batch['seg_info'] = [instance for instance in instances]
|
|
|
|
|
|
if 'dataset_type' in instances[0]:
|
|
|
batch['dataset_type'] = [instance['dataset_type'] for instance in instances]
|
|
|
|
|
|
if 'class_name_ids' in instances[0]:
|
|
|
class_name_ids = [instance['class_name_ids'] for instance in instances]
|
|
|
if any(x.shape != class_name_ids[0].shape for x in class_name_ids):
|
|
|
batch['class_name_ids'] = torch.nn.utils.rnn.pad_sequence(
|
|
|
class_name_ids,
|
|
|
batch_first=True,
|
|
|
padding_value=-1,
|
|
|
)
|
|
|
else:
|
|
|
batch['class_name_ids'] = torch.stack(class_name_ids, dim=0)
|
|
|
if 'token_refer_id' in instances[0]:
|
|
|
token_refer_id = [instance['token_refer_id'] for instance in instances]
|
|
|
batch['token_refer_id'] = token_refer_id
|
|
|
if 'cls_indices' in instances[0]:
|
|
|
cls_indices = [instance['cls_indices'] for instance in instances]
|
|
|
if any(x.shape != cls_indices[0].shape for x in cls_indices):
|
|
|
batch['cls_indices'] = torch.nn.utils.rnn.pad_sequence(
|
|
|
cls_indices,
|
|
|
batch_first=True,
|
|
|
padding_value=-1,
|
|
|
)
|
|
|
else:
|
|
|
batch['cls_indices'] = torch.stack(cls_indices, dim=0)
|
|
|
if 'random_idx' in instances[0]:
|
|
|
random_idxs = [instance['random_idx'] for instance in instances]
|
|
|
batch['random_idx'] = torch.stack(random_idxs, dim=0)
|
|
|
if 'class_name_embedding_indices' in instances[0]:
|
|
|
class_name_embedding_indices = [instance['class_name_embedding_indices'] for instance in instances]
|
|
|
class_name_embedding_indices = torch.nn.utils.rnn.pad_sequence(
|
|
|
class_name_embedding_indices,
|
|
|
batch_first=True,
|
|
|
padding_value=0)
|
|
|
batch['class_name_embedding_indices'] = class_name_embedding_indices
|
|
|
if 'refer_embedding_indices' in instances[0]:
|
|
|
refer_embedding_indices = [instance['refer_embedding_indices'] for instance in instances]
|
|
|
refer_embedding_indices = torch.nn.utils.rnn.pad_sequence(
|
|
|
refer_embedding_indices,
|
|
|
batch_first=True,
|
|
|
padding_value=0)
|
|
|
batch['refer_embedding_indices'] = refer_embedding_indices
|
|
|
|
|
|
return batch
|
|
|
|