zhouyik committed (verified)
Commit 4ee9c8f · 1 Parent(s): ef577d4

Upload folder using huggingface_hub
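For reference, an upload like this is typically produced with huggingface_hub's folder-upload API. A minimal sketch of that kind of call (the repo id is a placeholder, and the local folder path is taken from the script below):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./work_dirs/hf_pano_vlm",   # local folder exported for upload
    repo_id="zhouyik/hf_pano_vlm",           # placeholder repo id, not part of this commit
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)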
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
__pycache__/configuration_intern_vit.cpython-310.pyc ADDED
Binary file (5.03 kB)

__pycache__/configuration_internlm2.cpython-310.pyc ADDED
Binary file (5.54 kB)

__pycache__/configuration_mask2former.cpython-310.pyc ADDED
Binary file (10.5 kB)

__pycache__/configuration_phi3.cpython-310.pyc ADDED
Binary file (8.67 kB)

__pycache__/configuration_sa2va_chat.cpython-310.pyc ADDED
Binary file (3.57 kB)

__pycache__/constants.cpython-310.pyc ADDED
Binary file (555 Bytes)

__pycache__/flash_attention.cpython-310.pyc ADDED
Binary file (2.71 kB)

__pycache__/mask2former.cpython-310.pyc ADDED
Binary file (17.5 kB)

__pycache__/modeling_intern_vit.cpython-310.pyc ADDED
Binary file (12.9 kB)

__pycache__/modeling_internlm2.cpython-310.pyc ADDED
Binary file (42.9 kB)

__pycache__/modeling_phi3.cpython-310.pyc ADDED
Binary file (44.2 kB)

__pycache__/modeling_sa2va_chat.cpython-310.pyc ADDED
Binary file (27.1 kB)

__pycache__/templates.cpython-310.pyc ADDED
Binary file (3.86 kB)
 
added_tokens.json ADDED
@@ -0,0 +1,140 @@
+ {
+ "</box>": 151673,
+ "</img>": 151666,
+ "</obj>": 151679,
+ "</p>": 151675,
+ "</quad>": 151669,
+ "</ref>": 151671,
+ "</tool_call>": 151658,
+ "<IMG_CONTEXT>": 151667,
+ "<OBJ_CONTEXT>": 151680,
+ "<box>": 151672,
+ "<img>": 151665,
+ "<obj>": 151678,
+ "<p>": 151674,
+ "<quad>": 151668,
+ "<ref>": 151670,
+ "<tool_call>": 151657,
+ "<|box_end|>": 151649,
+ "<|box_start|>": 151648,
+ "<|endoftext|>": 151643,
+ "<|file_sep|>": 151664,
+ "<|fim_middle|>": 151660,
+ "<|fim_pad|>": 151662,
+ "<|fim_prefix|>": 151659,
+ "<|fim_suffix|>": 151661,
+ "<|im_end|>": 151645,
+ "<|im_start|>": 151644,
+ "<|image_pad|>": 151655,
+ "<|object_ref_end|>": 151647,
+ "<|object_ref_start|>": 151646,
+ "<|quad_end|>": 151651,
+ "<|quad_start|>": 151650,
+ "<|repo_name|>": 151663,
+ "<|video_pad|>": 151656,
+ "<|vision_end|>": 151653,
+ "<|vision_pad|>": 151654,
+ "<|vision_start|>": 151652,
+ "[BG_CLS]": 151677,
+ "[CLS]": 151676,
+ "[SEG000]": 151681,
+ "[SEG001]": 151682,
+ "[SEG002]": 151683,
+ "[SEG003]": 151684,
+ "[SEG004]": 151685,
+ "[SEG005]": 151686,
+ "[SEG006]": 151687,
+ "[SEG007]": 151688,
+ "[SEG008]": 151689,
+ "[SEG009]": 151690,
+ "[SEG010]": 151691,
+ "[SEG011]": 151692,
+ "[SEG012]": 151693,
+ "[SEG013]": 151694,
+ "[SEG014]": 151695,
+ "[SEG015]": 151696,
+ "[SEG016]": 151697,
+ "[SEG017]": 151698,
+ "[SEG018]": 151699,
+ "[SEG019]": 151700,
+ "[SEG020]": 151701,
+ "[SEG021]": 151702,
+ "[SEG022]": 151703,
+ "[SEG023]": 151704,
+ "[SEG024]": 151705,
+ "[SEG025]": 151706,
+ "[SEG026]": 151707,
+ "[SEG027]": 151708,
+ "[SEG028]": 151709,
+ "[SEG029]": 151710,
+ "[SEG030]": 151711,
+ "[SEG031]": 151712,
+ "[SEG032]": 151713,
+ "[SEG033]": 151714,
+ "[SEG034]": 151715,
+ "[SEG035]": 151716,
+ "[SEG036]": 151717,
+ "[SEG037]": 151718,
+ "[SEG038]": 151719,
+ "[SEG039]": 151720,
+ "[SEG040]": 151721,
+ "[SEG041]": 151722,
+ "[SEG042]": 151723,
+ "[SEG043]": 151724,
+ "[SEG044]": 151725,
+ "[SEG045]": 151726,
+ "[SEG046]": 151727,
+ "[SEG047]": 151728,
+ "[SEG048]": 151729,
+ "[SEG049]": 151730,
+ "[SEG050]": 151731,
+ "[SEG051]": 151732,
+ "[SEG052]": 151733,
+ "[SEG053]": 151734,
+ "[SEG054]": 151735,
+ "[SEG055]": 151736,
+ "[SEG056]": 151737,
+ "[SEG057]": 151738,
+ "[SEG058]": 151739,
+ "[SEG059]": 151740,
+ "[SEG060]": 151741,
+ "[SEG061]": 151742,
+ "[SEG062]": 151743,
+ "[SEG063]": 151744,
+ "[SEG064]": 151745,
+ "[SEG065]": 151746,
+ "[SEG066]": 151747,
+ "[SEG067]": 151748,
+ "[SEG068]": 151749,
+ "[SEG069]": 151750,
+ "[SEG070]": 151751,
+ "[SEG071]": 151752,
+ "[SEG072]": 151753,
+ "[SEG073]": 151754,
+ "[SEG074]": 151755,
+ "[SEG075]": 151756,
+ "[SEG076]": 151757,
+ "[SEG077]": 151758,
+ "[SEG078]": 151759,
+ "[SEG079]": 151760,
+ "[SEG080]": 151761,
+ "[SEG081]": 151762,
+ "[SEG082]": 151763,
+ "[SEG083]": 151764,
+ "[SEG084]": 151765,
+ "[SEG085]": 151766,
+ "[SEG086]": 151767,
+ "[SEG087]": 151768,
+ "[SEG088]": 151769,
+ "[SEG089]": 151770,
+ "[SEG090]": 151771,
+ "[SEG091]": 151772,
+ "[SEG092]": 151773,
+ "[SEG093]": 151774,
+ "[SEG094]": 151775,
+ "[SEG095]": 151776,
+ "[SEG096]": 151777,
+ "[SEG097]": 151778,
+ "[SEG098]": 151779,
+ "[SEG099]": 151780
+ }
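
Once the tokenizer shipped in this upload is loaded, the new entries can be sanity-checked against the ids above. A small sketch, assuming the repo has been downloaded to a local folder (the path is a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./hf_pano_vlm", trust_remote_code=True, use_fast=False)

# Ids taken from added_tokens.json above.
assert tokenizer.convert_tokens_to_ids("[SEG000]") == 151681
assert tokenizer.convert_tokens_to_ids("[CLS]") == 151676
assert tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>") == 151667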
chat_with_sa2va.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+
+ from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+
+ from types import MethodType
+ import matplotlib.colors as mplc
+ from detectron2.data import MetadataCatalog
+ from detectron2.data.detection_utils import read_image
+ from detectron2.utils.visualizer import ColorMode, GenericMask, Visualizer
+
+
+ def draw_instance_predictions_cache(self, labels, np_masks, jittering: bool = True):
+     """
+     Draw instance-level prediction results on an image.
+
+     Args:
+         labels: one text label per mask.
+         np_masks: binary masks (numpy arrays), one per instance.
+         jittering: if True, in color mode SEGMENTATION, randomly jitter the colors per class
+             to distinguish instances from the same class.
+
+     Returns:
+         output (VisImage): image object with visualizations.
+     """
+     boxes = None
+     scores = None
+     classes = None
+     keypoints = None
+
+     masks = [GenericMask(x, self.output.height, self.output.width) for x in np_masks]
+
+     if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"):
+         colors = (
+             [self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes]
+             if jittering
+             else [
+                 tuple(mplc.to_rgb([x / 255 for x in self.metadata.thing_colors[c]]))
+                 for c in classes
+             ]
+         )
+         alpha = 0.8
+     else:
+         colors = None
+         alpha = 0.5
+
+     self.overlay_instances(
+         masks=masks,
+         boxes=boxes,
+         labels=labels,
+         keypoints=keypoints,
+         assigned_colors=colors,
+         alpha=alpha,
+     )
+     return self.output
+
+
+ def visualize(image_path, cat_masks, out_path, tags):
+     if tags is None:
+         left_tags = [f'{i}' for i in range(len(cat_masks))]
+     else:
+         left_tags = tags
+
+     unique_tags = list(set(left_tags))
+     text_prompt = ','.join(unique_tags)
+     metadata = MetadataCatalog.get("__unused_ape_" + text_prompt)
+     metadata.thing_classes = unique_tags
+     metadata.stuff_classes = unique_tags
+
+     result_masks = cat_masks
+     input_image = read_image(image_path, format="BGR")
+     visualizer = Visualizer(input_image[:, :, ::-1], metadata, instance_mode=ColorMode.IMAGE)
+     # Swap in the patched drawing method so plain numpy masks and text labels can be rendered.
+     visualizer.draw_instance_predictions = MethodType(draw_instance_predictions_cache, visualizer)
+     vis_output = visualizer.draw_instance_predictions(labels=left_tags, np_masks=result_masks)
+     output_image = vis_output.get_image()
+     output_image = Image.fromarray(output_image)
+
+     output_image.save(out_path)
+
+
+ path = "./work_dirs/hf_pano_vlm"
+ model = AutoModel.from_pretrained(
+     path,
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     use_flash_attn=True,
+     trust_remote_code=True).eval().cuda()
+
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
+
+ image_path = "./FRAME02_ORI.jpg"
+ image = Image.open(image_path)
+ width, height = image.size
+
+ # Build a class prompt from the full COCO category list (kept for reference),
+ # then use a shorter hand-picked prompt for this demo.
+ from projects.llava_sam2.datasets.coco_category import COCO_CATEGORIES
+ coco_category_names = ""
+ for item in COCO_CATEGORIES:
+     class_name = item['name']
+     coco_category_names += f"<p>{class_name}</p> [CLS], "
+ coco_category_names = coco_category_names[:-2]
+ # question = f"<image>\nSegment from the class prompt: {coco_category_names}."
+ question = f"<image>\nSegment from the class prompt: <p>person</p> [CLS], <p>car</p> [CLS], <p>road</p> [CLS], <p>tree</p> [CLS], <p>building</p> [CLS], <p>ground</p> [CLS]."
+
+ m2f_processor = AutoImageProcessor.from_pretrained("./facebook/mask2former-swin-large-coco-panoptic", trust_remote_code=True)
+
+ chat_outputs = model.predict_forward(text=question, image=image, tokenizer=tokenizer, m2f_processor=m2f_processor)
+ answer = chat_outputs['prediction']
+ masks = chat_outputs['prediction_masks']
+
+ m2f_outputs = chat_outputs['m2f_outputs']
+ label_id_to_text = m2f_outputs['label_id_to_text']
+
+ post_m2f_outputs = model.post_process_panoptic_segmentation(
+     m2f_outputs['class_queries_logits'],
+     m2f_outputs['masks_queries_logits'],
+     target_sizes=[(height, width)],
+ )
+
+ print(f"user: {question}")
+ print(f"assistant: {answer}")
+
+ segmentation = post_m2f_outputs[0]['segmentation']
+ segments_info = post_m2f_outputs[0]['segments_info']
+ pano_masks, pano_tags = [], []
+ for item in segments_info:
+     mask = segmentation == item['id']
+     pano_masks.append(mask.unsqueeze(0).cpu().numpy())
+     pano_tags.append(label_id_to_text[item['label_id']])
+
+ pano_masks = np.concatenate(pano_masks, axis=0)
+
+ visualize(image_path, pano_masks, "./visualize_test_4.jpg", pano_tags)
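
As a follow-up to the script above, the per-segment results it builds (pano_masks, an (N, H, W) boolean array, paired one-to-one with pano_tags) can also be written out individually for inspection. A minimal sketch with PIL, assuming those two variables as produced above:

import numpy as np
from PIL import Image

# pano_masks: (N, H, W) bool array; pano_tags: list of N label strings (see script above).
for i, (mask, tag) in enumerate(zip(pano_masks, pano_tags)):
    mask_img = Image.fromarray(mask.astype(np.uint8) * 255)  # 0/255 grayscale mask
    safe_tag = tag.replace(" ", "_").replace(",", "")
    mask_img.save(f"./mask_{i:02d}_{safe_tag}.png")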
config.json ADDED
@@ -0,0 +1,2677 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "./OpenGVLab/InternVL2_5-4B",
4
+ "architectures": [
5
+ "Sa2VAChatModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_sa2va_chat.Sa2VAChatConfig",
9
+ "AutoModel": "modeling_sa2va_chat.Sa2VAChatModel",
10
+ "AutoModelForCausalLM": "modeling_sa2va_chat.Sa2VAChatModel"
11
+ },
12
+ "downsample_ratio": 0.5,
13
+ "dynamic_image_size": true,
14
+ "force_image_size": 448,
15
+ "hidden_size": 2048,
16
+ "llm_config": {
17
+ "_attn_implementation_autoset": false,
18
+ "_name_or_path": "Qwen/Qwen2.5-3B-Instruct",
19
+ "add_cross_attention": false,
20
+ "architectures": [
21
+ "Qwen2ForCausalLM"
22
+ ],
23
+ "attention_dropout": 0.0,
24
+ "bad_words_ids": null,
25
+ "begin_suppress_tokens": null,
26
+ "bos_token_id": 151643,
27
+ "chunk_size_feed_forward": 0,
28
+ "cross_attention_hidden_size": null,
29
+ "decoder_start_token_id": null,
30
+ "diversity_penalty": 0.0,
31
+ "do_sample": false,
32
+ "early_stopping": false,
33
+ "encoder_no_repeat_ngram_size": 0,
34
+ "eos_token_id": 151645,
35
+ "exponential_decay_length_penalty": null,
36
+ "finetuning_task": null,
37
+ "forced_bos_token_id": null,
38
+ "forced_eos_token_id": null,
39
+ "hidden_act": "silu",
40
+ "hidden_size": 2048,
41
+ "id2label": {
42
+ "0": "LABEL_0",
43
+ "1": "LABEL_1"
44
+ },
45
+ "initializer_range": 0.02,
46
+ "intermediate_size": 11008,
47
+ "is_decoder": false,
48
+ "is_encoder_decoder": false,
49
+ "label2id": {
50
+ "LABEL_0": 0,
51
+ "LABEL_1": 1
52
+ },
53
+ "length_penalty": 1.0,
54
+ "max_length": 20,
55
+ "max_position_embeddings": 32768,
56
+ "max_window_layers": 70,
57
+ "min_length": 0,
58
+ "model_type": "qwen2",
59
+ "no_repeat_ngram_size": 0,
60
+ "num_attention_heads": 16,
61
+ "num_beam_groups": 1,
62
+ "num_beams": 1,
63
+ "num_hidden_layers": 36,
64
+ "num_key_value_heads": 2,
65
+ "num_return_sequences": 1,
66
+ "output_attentions": false,
67
+ "output_hidden_states": false,
68
+ "output_scores": false,
69
+ "pad_token_id": null,
70
+ "prefix": null,
71
+ "problem_type": null,
72
+ "pruned_heads": {},
73
+ "remove_invalid_values": false,
74
+ "repetition_penalty": 1.0,
75
+ "return_dict": true,
76
+ "return_dict_in_generate": false,
77
+ "rms_norm_eps": 1e-06,
78
+ "rope_scaling": null,
79
+ "rope_theta": 1000000.0,
80
+ "sep_token_id": null,
81
+ "sliding_window": null,
82
+ "suppress_tokens": null,
83
+ "task_specific_params": null,
84
+ "temperature": 1.0,
85
+ "tf_legacy_loss": false,
86
+ "tie_encoder_decoder": false,
87
+ "tie_word_embeddings": false,
88
+ "tokenizer_class": null,
89
+ "top_k": 50,
90
+ "top_p": 1.0,
91
+ "torch_dtype": "bfloat16",
92
+ "torchscript": false,
93
+ "transformers_version": "4.47.0",
94
+ "typical_p": 1.0,
95
+ "use_bfloat16": true,
96
+ "use_cache": true,
97
+ "use_sliding_window": false,
98
+ "vocab_size": 151781
99
+ },
100
+ "m2f_config": {
101
+ "_attn_implementation_autoset": true,
102
+ "_name_or_path": "",
103
+ "activation_function": "relu",
104
+ "add_cross_attention": false,
105
+ "architectures": [
106
+ "Mask2FormerForUniversalSegmentation"
107
+ ],
108
+ "backbone": null,
109
+ "backbone_config": {
110
+ "_attn_implementation_autoset": false,
111
+ "_name_or_path": "",
112
+ "add_cross_attention": false,
113
+ "architectures": [
114
+ "SwinForImageClassification"
115
+ ],
116
+ "attention_probs_dropout_prob": 0.0,
117
+ "bad_words_ids": null,
118
+ "begin_suppress_tokens": null,
119
+ "bos_token_id": null,
120
+ "chunk_size_feed_forward": 0,
121
+ "cross_attention_hidden_size": null,
122
+ "decoder_start_token_id": null,
123
+ "depths": [
124
+ 2,
125
+ 2,
126
+ 18,
127
+ 2
128
+ ],
129
+ "diversity_penalty": 0.0,
130
+ "do_sample": false,
131
+ "drop_path_rate": 0.3,
132
+ "early_stopping": false,
133
+ "embed_dim": 192,
134
+ "encoder_no_repeat_ngram_size": 0,
135
+ "encoder_stride": 32,
136
+ "eos_token_id": null,
137
+ "exponential_decay_length_penalty": null,
138
+ "finetuning_task": null,
139
+ "forced_bos_token_id": null,
140
+ "forced_eos_token_id": null,
141
+ "hidden_act": "gelu",
142
+ "hidden_dropout_prob": 0.0,
143
+ "hidden_size": 1536,
144
+ "id2label": {
145
+ "0": "tench, Tinca tinca",
146
+ "1": "goldfish, Carassius auratus",
147
+ "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
148
+ "3": "tiger shark, Galeocerdo cuvieri",
149
+ "4": "hammerhead, hammerhead shark",
150
+ "5": "electric ray, crampfish, numbfish, torpedo",
151
+ "6": "stingray",
152
+ "7": "cock",
153
+ "8": "hen",
154
+ "9": "ostrich, Struthio camelus",
155
+ "10": "brambling, Fringilla montifringilla",
156
+ "11": "goldfinch, Carduelis carduelis",
157
+ "12": "house finch, linnet, Carpodacus mexicanus",
158
+ "13": "junco, snowbird",
159
+ "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
160
+ "15": "robin, American robin, Turdus migratorius",
161
+ "16": "bulbul",
162
+ "17": "jay",
163
+ "18": "magpie",
164
+ "19": "chickadee",
165
+ "20": "water ouzel, dipper",
166
+ "21": "kite",
167
+ "22": "bald eagle, American eagle, Haliaeetus leucocephalus",
168
+ "23": "vulture",
169
+ "24": "great grey owl, great gray owl, Strix nebulosa",
170
+ "25": "European fire salamander, Salamandra salamandra",
171
+ "26": "common newt, Triturus vulgaris",
172
+ "27": "eft",
173
+ "28": "spotted salamander, Ambystoma maculatum",
174
+ "29": "axolotl, mud puppy, Ambystoma mexicanum",
175
+ "30": "bullfrog, Rana catesbeiana",
176
+ "31": "tree frog, tree-frog",
177
+ "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
178
+ "33": "loggerhead, loggerhead turtle, Caretta caretta",
179
+ "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
180
+ "35": "mud turtle",
181
+ "36": "terrapin",
182
+ "37": "box turtle, box tortoise",
183
+ "38": "banded gecko",
184
+ "39": "common iguana, iguana, Iguana iguana",
185
+ "40": "American chameleon, anole, Anolis carolinensis",
186
+ "41": "whiptail, whiptail lizard",
187
+ "42": "agama",
188
+ "43": "frilled lizard, Chlamydosaurus kingi",
189
+ "44": "alligator lizard",
190
+ "45": "Gila monster, Heloderma suspectum",
191
+ "46": "green lizard, Lacerta viridis",
192
+ "47": "African chameleon, Chamaeleo chamaeleon",
193
+ "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
194
+ "49": "African crocodile, Nile crocodile, Crocodylus niloticus",
195
+ "50": "American alligator, Alligator mississipiensis",
196
+ "51": "triceratops",
197
+ "52": "thunder snake, worm snake, Carphophis amoenus",
198
+ "53": "ringneck snake, ring-necked snake, ring snake",
199
+ "54": "hognose snake, puff adder, sand viper",
200
+ "55": "green snake, grass snake",
201
+ "56": "king snake, kingsnake",
202
+ "57": "garter snake, grass snake",
203
+ "58": "water snake",
204
+ "59": "vine snake",
205
+ "60": "night snake, Hypsiglena torquata",
206
+ "61": "boa constrictor, Constrictor constrictor",
207
+ "62": "rock python, rock snake, Python sebae",
208
+ "63": "Indian cobra, Naja naja",
209
+ "64": "green mamba",
210
+ "65": "sea snake",
211
+ "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
212
+ "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
213
+ "68": "sidewinder, horned rattlesnake, Crotalus cerastes",
214
+ "69": "trilobite",
215
+ "70": "harvestman, daddy longlegs, Phalangium opilio",
216
+ "71": "scorpion",
217
+ "72": "black and gold garden spider, Argiope aurantia",
218
+ "73": "barn spider, Araneus cavaticus",
219
+ "74": "garden spider, Aranea diademata",
220
+ "75": "black widow, Latrodectus mactans",
221
+ "76": "tarantula",
222
+ "77": "wolf spider, hunting spider",
223
+ "78": "tick",
224
+ "79": "centipede",
225
+ "80": "black grouse",
226
+ "81": "ptarmigan",
227
+ "82": "ruffed grouse, partridge, Bonasa umbellus",
228
+ "83": "prairie chicken, prairie grouse, prairie fowl",
229
+ "84": "peacock",
230
+ "85": "quail",
231
+ "86": "partridge",
232
+ "87": "African grey, African gray, Psittacus erithacus",
233
+ "88": "macaw",
234
+ "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
235
+ "90": "lorikeet",
236
+ "91": "coucal",
237
+ "92": "bee eater",
238
+ "93": "hornbill",
239
+ "94": "hummingbird",
240
+ "95": "jacamar",
241
+ "96": "toucan",
242
+ "97": "drake",
243
+ "98": "red-breasted merganser, Mergus serrator",
244
+ "99": "goose",
245
+ "100": "black swan, Cygnus atratus",
246
+ "101": "tusker",
247
+ "102": "echidna, spiny anteater, anteater",
248
+ "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
249
+ "104": "wallaby, brush kangaroo",
250
+ "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
251
+ "106": "wombat",
252
+ "107": "jellyfish",
253
+ "108": "sea anemone, anemone",
254
+ "109": "brain coral",
255
+ "110": "flatworm, platyhelminth",
256
+ "111": "nematode, nematode worm, roundworm",
257
+ "112": "conch",
258
+ "113": "snail",
259
+ "114": "slug",
260
+ "115": "sea slug, nudibranch",
261
+ "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
262
+ "117": "chambered nautilus, pearly nautilus, nautilus",
263
+ "118": "Dungeness crab, Cancer magister",
264
+ "119": "rock crab, Cancer irroratus",
265
+ "120": "fiddler crab",
266
+ "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
267
+ "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
268
+ "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
269
+ "124": "crayfish, crawfish, crawdad, crawdaddy",
270
+ "125": "hermit crab",
271
+ "126": "isopod",
272
+ "127": "white stork, Ciconia ciconia",
273
+ "128": "black stork, Ciconia nigra",
274
+ "129": "spoonbill",
275
+ "130": "flamingo",
276
+ "131": "little blue heron, Egretta caerulea",
277
+ "132": "American egret, great white heron, Egretta albus",
278
+ "133": "bittern",
279
+ "134": "crane",
280
+ "135": "limpkin, Aramus pictus",
281
+ "136": "European gallinule, Porphyrio porphyrio",
282
+ "137": "American coot, marsh hen, mud hen, water hen, Fulica americana",
283
+ "138": "bustard",
284
+ "139": "ruddy turnstone, Arenaria interpres",
285
+ "140": "red-backed sandpiper, dunlin, Erolia alpina",
286
+ "141": "redshank, Tringa totanus",
287
+ "142": "dowitcher",
288
+ "143": "oystercatcher, oyster catcher",
289
+ "144": "pelican",
290
+ "145": "king penguin, Aptenodytes patagonica",
291
+ "146": "albatross, mollymawk",
292
+ "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
293
+ "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
294
+ "149": "dugong, Dugong dugon",
295
+ "150": "sea lion",
296
+ "151": "Chihuahua",
297
+ "152": "Japanese spaniel",
298
+ "153": "Maltese dog, Maltese terrier, Maltese",
299
+ "154": "Pekinese, Pekingese, Peke",
300
+ "155": "Shih-Tzu",
301
+ "156": "Blenheim spaniel",
302
+ "157": "papillon",
303
+ "158": "toy terrier",
304
+ "159": "Rhodesian ridgeback",
305
+ "160": "Afghan hound, Afghan",
306
+ "161": "basset, basset hound",
307
+ "162": "beagle",
308
+ "163": "bloodhound, sleuthhound",
309
+ "164": "bluetick",
310
+ "165": "black-and-tan coonhound",
311
+ "166": "Walker hound, Walker foxhound",
312
+ "167": "English foxhound",
313
+ "168": "redbone",
314
+ "169": "borzoi, Russian wolfhound",
315
+ "170": "Irish wolfhound",
316
+ "171": "Italian greyhound",
317
+ "172": "whippet",
318
+ "173": "Ibizan hound, Ibizan Podenco",
319
+ "174": "Norwegian elkhound, elkhound",
320
+ "175": "otterhound, otter hound",
321
+ "176": "Saluki, gazelle hound",
322
+ "177": "Scottish deerhound, deerhound",
323
+ "178": "Weimaraner",
324
+ "179": "Staffordshire bullterrier, Staffordshire bull terrier",
325
+ "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
326
+ "181": "Bedlington terrier",
327
+ "182": "Border terrier",
328
+ "183": "Kerry blue terrier",
329
+ "184": "Irish terrier",
330
+ "185": "Norfolk terrier",
331
+ "186": "Norwich terrier",
332
+ "187": "Yorkshire terrier",
333
+ "188": "wire-haired fox terrier",
334
+ "189": "Lakeland terrier",
335
+ "190": "Sealyham terrier, Sealyham",
336
+ "191": "Airedale, Airedale terrier",
337
+ "192": "cairn, cairn terrier",
338
+ "193": "Australian terrier",
339
+ "194": "Dandie Dinmont, Dandie Dinmont terrier",
340
+ "195": "Boston bull, Boston terrier",
341
+ "196": "miniature schnauzer",
342
+ "197": "giant schnauzer",
343
+ "198": "standard schnauzer",
344
+ "199": "Scotch terrier, Scottish terrier, Scottie",
345
+ "200": "Tibetan terrier, chrysanthemum dog",
346
+ "201": "silky terrier, Sydney silky",
347
+ "202": "soft-coated wheaten terrier",
348
+ "203": "West Highland white terrier",
349
+ "204": "Lhasa, Lhasa apso",
350
+ "205": "flat-coated retriever",
351
+ "206": "curly-coated retriever",
352
+ "207": "golden retriever",
353
+ "208": "Labrador retriever",
354
+ "209": "Chesapeake Bay retriever",
355
+ "210": "German short-haired pointer",
356
+ "211": "vizsla, Hungarian pointer",
357
+ "212": "English setter",
358
+ "213": "Irish setter, red setter",
359
+ "214": "Gordon setter",
360
+ "215": "Brittany spaniel",
361
+ "216": "clumber, clumber spaniel",
362
+ "217": "English springer, English springer spaniel",
363
+ "218": "Welsh springer spaniel",
364
+ "219": "cocker spaniel, English cocker spaniel, cocker",
365
+ "220": "Sussex spaniel",
366
+ "221": "Irish water spaniel",
367
+ "222": "kuvasz",
368
+ "223": "schipperke",
369
+ "224": "groenendael",
370
+ "225": "malinois",
371
+ "226": "briard",
372
+ "227": "kelpie",
373
+ "228": "komondor",
374
+ "229": "Old English sheepdog, bobtail",
375
+ "230": "Shetland sheepdog, Shetland sheep dog, Shetland",
376
+ "231": "collie",
377
+ "232": "Border collie",
378
+ "233": "Bouvier des Flandres, Bouviers des Flandres",
379
+ "234": "Rottweiler",
380
+ "235": "German shepherd, German shepherd dog, German police dog, alsatian",
381
+ "236": "Doberman, Doberman pinscher",
382
+ "237": "miniature pinscher",
383
+ "238": "Greater Swiss Mountain dog",
384
+ "239": "Bernese mountain dog",
385
+ "240": "Appenzeller",
386
+ "241": "EntleBucher",
387
+ "242": "boxer",
388
+ "243": "bull mastiff",
389
+ "244": "Tibetan mastiff",
390
+ "245": "French bulldog",
391
+ "246": "Great Dane",
392
+ "247": "Saint Bernard, St Bernard",
393
+ "248": "Eskimo dog, husky",
394
+ "249": "malamute, malemute, Alaskan malamute",
395
+ "250": "Siberian husky",
396
+ "251": "dalmatian, coach dog, carriage dog",
397
+ "252": "affenpinscher, monkey pinscher, monkey dog",
398
+ "253": "basenji",
399
+ "254": "pug, pug-dog",
400
+ "255": "Leonberg",
401
+ "256": "Newfoundland, Newfoundland dog",
402
+ "257": "Great Pyrenees",
403
+ "258": "Samoyed, Samoyede",
404
+ "259": "Pomeranian",
405
+ "260": "chow, chow chow",
406
+ "261": "keeshond",
407
+ "262": "Brabancon griffon",
408
+ "263": "Pembroke, Pembroke Welsh corgi",
409
+ "264": "Cardigan, Cardigan Welsh corgi",
410
+ "265": "toy poodle",
411
+ "266": "miniature poodle",
412
+ "267": "standard poodle",
413
+ "268": "Mexican hairless",
414
+ "269": "timber wolf, grey wolf, gray wolf, Canis lupus",
415
+ "270": "white wolf, Arctic wolf, Canis lupus tundrarum",
416
+ "271": "red wolf, maned wolf, Canis rufus, Canis niger",
417
+ "272": "coyote, prairie wolf, brush wolf, Canis latrans",
418
+ "273": "dingo, warrigal, warragal, Canis dingo",
419
+ "274": "dhole, Cuon alpinus",
420
+ "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
421
+ "276": "hyena, hyaena",
422
+ "277": "red fox, Vulpes vulpes",
423
+ "278": "kit fox, Vulpes macrotis",
424
+ "279": "Arctic fox, white fox, Alopex lagopus",
425
+ "280": "grey fox, gray fox, Urocyon cinereoargenteus",
426
+ "281": "tabby, tabby cat",
427
+ "282": "tiger cat",
428
+ "283": "Persian cat",
429
+ "284": "Siamese cat, Siamese",
430
+ "285": "Egyptian cat",
431
+ "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
432
+ "287": "lynx, catamount",
433
+ "288": "leopard, Panthera pardus",
434
+ "289": "snow leopard, ounce, Panthera uncia",
435
+ "290": "jaguar, panther, Panthera onca, Felis onca",
436
+ "291": "lion, king of beasts, Panthera leo",
437
+ "292": "tiger, Panthera tigris",
438
+ "293": "cheetah, chetah, Acinonyx jubatus",
439
+ "294": "brown bear, bruin, Ursus arctos",
440
+ "295": "American black bear, black bear, Ursus americanus, Euarctos americanus",
441
+ "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
442
+ "297": "sloth bear, Melursus ursinus, Ursus ursinus",
443
+ "298": "mongoose",
444
+ "299": "meerkat, mierkat",
445
+ "300": "tiger beetle",
446
+ "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
447
+ "302": "ground beetle, carabid beetle",
448
+ "303": "long-horned beetle, longicorn, longicorn beetle",
449
+ "304": "leaf beetle, chrysomelid",
450
+ "305": "dung beetle",
451
+ "306": "rhinoceros beetle",
452
+ "307": "weevil",
453
+ "308": "fly",
454
+ "309": "bee",
455
+ "310": "ant, emmet, pismire",
456
+ "311": "grasshopper, hopper",
457
+ "312": "cricket",
458
+ "313": "walking stick, walkingstick, stick insect",
459
+ "314": "cockroach, roach",
460
+ "315": "mantis, mantid",
461
+ "316": "cicada, cicala",
462
+ "317": "leafhopper",
463
+ "318": "lacewing, lacewing fly",
464
+ "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
465
+ "320": "damselfly",
466
+ "321": "admiral",
467
+ "322": "ringlet, ringlet butterfly",
468
+ "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
469
+ "324": "cabbage butterfly",
470
+ "325": "sulphur butterfly, sulfur butterfly",
471
+ "326": "lycaenid, lycaenid butterfly",
472
+ "327": "starfish, sea star",
473
+ "328": "sea urchin",
474
+ "329": "sea cucumber, holothurian",
475
+ "330": "wood rabbit, cottontail, cottontail rabbit",
476
+ "331": "hare",
477
+ "332": "Angora, Angora rabbit",
478
+ "333": "hamster",
479
+ "334": "porcupine, hedgehog",
480
+ "335": "fox squirrel, eastern fox squirrel, Sciurus niger",
481
+ "336": "marmot",
482
+ "337": "beaver",
483
+ "338": "guinea pig, Cavia cobaya",
484
+ "339": "sorrel",
485
+ "340": "zebra",
486
+ "341": "hog, pig, grunter, squealer, Sus scrofa",
487
+ "342": "wild boar, boar, Sus scrofa",
488
+ "343": "warthog",
489
+ "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
490
+ "345": "ox",
491
+ "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
492
+ "347": "bison",
493
+ "348": "ram, tup",
494
+ "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
495
+ "350": "ibex, Capra ibex",
496
+ "351": "hartebeest",
497
+ "352": "impala, Aepyceros melampus",
498
+ "353": "gazelle",
499
+ "354": "Arabian camel, dromedary, Camelus dromedarius",
500
+ "355": "llama",
501
+ "356": "weasel",
502
+ "357": "mink",
503
+ "358": "polecat, fitch, foulmart, foumart, Mustela putorius",
504
+ "359": "black-footed ferret, ferret, Mustela nigripes",
505
+ "360": "otter",
506
+ "361": "skunk, polecat, wood pussy",
507
+ "362": "badger",
508
+ "363": "armadillo",
509
+ "364": "three-toed sloth, ai, Bradypus tridactylus",
510
+ "365": "orangutan, orang, orangutang, Pongo pygmaeus",
511
+ "366": "gorilla, Gorilla gorilla",
512
+ "367": "chimpanzee, chimp, Pan troglodytes",
513
+ "368": "gibbon, Hylobates lar",
514
+ "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
515
+ "370": "guenon, guenon monkey",
516
+ "371": "patas, hussar monkey, Erythrocebus patas",
517
+ "372": "baboon",
518
+ "373": "macaque",
519
+ "374": "langur",
520
+ "375": "colobus, colobus monkey",
521
+ "376": "proboscis monkey, Nasalis larvatus",
522
+ "377": "marmoset",
523
+ "378": "capuchin, ringtail, Cebus capucinus",
524
+ "379": "howler monkey, howler",
525
+ "380": "titi, titi monkey",
526
+ "381": "spider monkey, Ateles geoffroyi",
527
+ "382": "squirrel monkey, Saimiri sciureus",
528
+ "383": "Madagascar cat, ring-tailed lemur, Lemur catta",
529
+ "384": "indri, indris, Indri indri, Indri brevicaudatus",
530
+ "385": "Indian elephant, Elephas maximus",
531
+ "386": "African elephant, Loxodonta africana",
532
+ "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
533
+ "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
534
+ "389": "barracouta, snoek",
535
+ "390": "eel",
536
+ "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
537
+ "392": "rock beauty, Holocanthus tricolor",
538
+ "393": "anemone fish",
539
+ "394": "sturgeon",
540
+ "395": "gar, garfish, garpike, billfish, Lepisosteus osseus",
541
+ "396": "lionfish",
542
+ "397": "puffer, pufferfish, blowfish, globefish",
543
+ "398": "abacus",
544
+ "399": "abaya",
545
+ "400": "academic gown, academic robe, judge's robe",
546
+ "401": "accordion, piano accordion, squeeze box",
547
+ "402": "acoustic guitar",
548
+ "403": "aircraft carrier, carrier, flattop, attack aircraft carrier",
549
+ "404": "airliner",
550
+ "405": "airship, dirigible",
551
+ "406": "altar",
552
+ "407": "ambulance",
553
+ "408": "amphibian, amphibious vehicle",
554
+ "409": "analog clock",
555
+ "410": "apiary, bee house",
556
+ "411": "apron",
557
+ "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
558
+ "413": "assault rifle, assault gun",
559
+ "414": "backpack, back pack, knapsack, packsack, rucksack, haversack",
560
+ "415": "bakery, bakeshop, bakehouse",
561
+ "416": "balance beam, beam",
562
+ "417": "balloon",
563
+ "418": "ballpoint, ballpoint pen, ballpen, Biro",
564
+ "419": "Band Aid",
565
+ "420": "banjo",
566
+ "421": "bannister, banister, balustrade, balusters, handrail",
567
+ "422": "barbell",
568
+ "423": "barber chair",
569
+ "424": "barbershop",
570
+ "425": "barn",
571
+ "426": "barometer",
572
+ "427": "barrel, cask",
573
+ "428": "barrow, garden cart, lawn cart, wheelbarrow",
574
+ "429": "baseball",
575
+ "430": "basketball",
576
+ "431": "bassinet",
577
+ "432": "bassoon",
578
+ "433": "bathing cap, swimming cap",
579
+ "434": "bath towel",
580
+ "435": "bathtub, bathing tub, bath, tub",
581
+ "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
582
+ "437": "beacon, lighthouse, beacon light, pharos",
583
+ "438": "beaker",
584
+ "439": "bearskin, busby, shako",
585
+ "440": "beer bottle",
586
+ "441": "beer glass",
587
+ "442": "bell cote, bell cot",
588
+ "443": "bib",
589
+ "444": "bicycle-built-for-two, tandem bicycle, tandem",
590
+ "445": "bikini, two-piece",
591
+ "446": "binder, ring-binder",
592
+ "447": "binoculars, field glasses, opera glasses",
593
+ "448": "birdhouse",
594
+ "449": "boathouse",
595
+ "450": "bobsled, bobsleigh, bob",
596
+ "451": "bolo tie, bolo, bola tie, bola",
597
+ "452": "bonnet, poke bonnet",
598
+ "453": "bookcase",
599
+ "454": "bookshop, bookstore, bookstall",
600
+ "455": "bottlecap",
601
+ "456": "bow",
602
+ "457": "bow tie, bow-tie, bowtie",
603
+ "458": "brass, memorial tablet, plaque",
604
+ "459": "brassiere, bra, bandeau",
605
+ "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
606
+ "461": "breastplate, aegis, egis",
607
+ "462": "broom",
608
+ "463": "bucket, pail",
609
+ "464": "buckle",
610
+ "465": "bulletproof vest",
611
+ "466": "bullet train, bullet",
612
+ "467": "butcher shop, meat market",
613
+ "468": "cab, hack, taxi, taxicab",
614
+ "469": "caldron, cauldron",
615
+ "470": "candle, taper, wax light",
616
+ "471": "cannon",
617
+ "472": "canoe",
618
+ "473": "can opener, tin opener",
619
+ "474": "cardigan",
620
+ "475": "car mirror",
621
+ "476": "carousel, carrousel, merry-go-round, roundabout, whirligig",
622
+ "477": "carpenter's kit, tool kit",
623
+ "478": "carton",
624
+ "479": "car wheel",
625
+ "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
626
+ "481": "cassette",
627
+ "482": "cassette player",
628
+ "483": "castle",
629
+ "484": "catamaran",
630
+ "485": "CD player",
631
+ "486": "cello, violoncello",
632
+ "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
633
+ "488": "chain",
634
+ "489": "chainlink fence",
635
+ "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
636
+ "491": "chain saw, chainsaw",
637
+ "492": "chest",
638
+ "493": "chiffonier, commode",
639
+ "494": "chime, bell, gong",
640
+ "495": "china cabinet, china closet",
641
+ "496": "Christmas stocking",
642
+ "497": "church, church building",
643
+ "498": "cinema, movie theater, movie theatre, movie house, picture palace",
644
+ "499": "cleaver, meat cleaver, chopper",
645
+ "500": "cliff dwelling",
646
+ "501": "cloak",
647
+ "502": "clog, geta, patten, sabot",
648
+ "503": "cocktail shaker",
649
+ "504": "coffee mug",
650
+ "505": "coffeepot",
651
+ "506": "coil, spiral, volute, whorl, helix",
652
+ "507": "combination lock",
653
+ "508": "computer keyboard, keypad",
654
+ "509": "confectionery, confectionary, candy store",
655
+ "510": "container ship, containership, container vessel",
656
+ "511": "convertible",
657
+ "512": "corkscrew, bottle screw",
658
+ "513": "cornet, horn, trumpet, trump",
659
+ "514": "cowboy boot",
660
+ "515": "cowboy hat, ten-gallon hat",
661
+ "516": "cradle",
662
+ "517": "crane",
663
+ "518": "crash helmet",
664
+ "519": "crate",
665
+ "520": "crib, cot",
666
+ "521": "Crock Pot",
667
+ "522": "croquet ball",
668
+ "523": "crutch",
669
+ "524": "cuirass",
670
+ "525": "dam, dike, dyke",
671
+ "526": "desk",
672
+ "527": "desktop computer",
673
+ "528": "dial telephone, dial phone",
674
+ "529": "diaper, nappy, napkin",
675
+ "530": "digital clock",
676
+ "531": "digital watch",
677
+ "532": "dining table, board",
678
+ "533": "dishrag, dishcloth",
679
+ "534": "dishwasher, dish washer, dishwashing machine",
680
+ "535": "disk brake, disc brake",
681
+ "536": "dock, dockage, docking facility",
682
+ "537": "dogsled, dog sled, dog sleigh",
683
+ "538": "dome",
684
+ "539": "doormat, welcome mat",
685
+ "540": "drilling platform, offshore rig",
686
+ "541": "drum, membranophone, tympan",
687
+ "542": "drumstick",
688
+ "543": "dumbbell",
689
+ "544": "Dutch oven",
690
+ "545": "electric fan, blower",
691
+ "546": "electric guitar",
692
+ "547": "electric locomotive",
693
+ "548": "entertainment center",
694
+ "549": "envelope",
695
+ "550": "espresso maker",
696
+ "551": "face powder",
697
+ "552": "feather boa, boa",
698
+ "553": "file, file cabinet, filing cabinet",
699
+ "554": "fireboat",
700
+ "555": "fire engine, fire truck",
701
+ "556": "fire screen, fireguard",
702
+ "557": "flagpole, flagstaff",
703
+ "558": "flute, transverse flute",
704
+ "559": "folding chair",
705
+ "560": "football helmet",
706
+ "561": "forklift",
707
+ "562": "fountain",
708
+ "563": "fountain pen",
709
+ "564": "four-poster",
710
+ "565": "freight car",
711
+ "566": "French horn, horn",
712
+ "567": "frying pan, frypan, skillet",
713
+ "568": "fur coat",
714
+ "569": "garbage truck, dustcart",
715
+ "570": "gasmask, respirator, gas helmet",
716
+ "571": "gas pump, gasoline pump, petrol pump, island dispenser",
717
+ "572": "goblet",
718
+ "573": "go-kart",
719
+ "574": "golf ball",
720
+ "575": "golfcart, golf cart",
721
+ "576": "gondola",
722
+ "577": "gong, tam-tam",
723
+ "578": "gown",
724
+ "579": "grand piano, grand",
725
+ "580": "greenhouse, nursery, glasshouse",
726
+ "581": "grille, radiator grille",
727
+ "582": "grocery store, grocery, food market, market",
728
+ "583": "guillotine",
729
+ "584": "hair slide",
730
+ "585": "hair spray",
731
+ "586": "half track",
732
+ "587": "hammer",
733
+ "588": "hamper",
734
+ "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
735
+ "590": "hand-held computer, hand-held microcomputer",
736
+ "591": "handkerchief, hankie, hanky, hankey",
737
+ "592": "hard disc, hard disk, fixed disk",
738
+ "593": "harmonica, mouth organ, harp, mouth harp",
739
+ "594": "harp",
740
+ "595": "harvester, reaper",
741
+ "596": "hatchet",
742
+ "597": "holster",
743
+ "598": "home theater, home theatre",
744
+ "599": "honeycomb",
745
+ "600": "hook, claw",
746
+ "601": "hoopskirt, crinoline",
747
+ "602": "horizontal bar, high bar",
748
+ "603": "horse cart, horse-cart",
749
+ "604": "hourglass",
750
+ "605": "iPod",
751
+ "606": "iron, smoothing iron",
752
+ "607": "jack-o'-lantern",
753
+ "608": "jean, blue jean, denim",
754
+ "609": "jeep, landrover",
755
+ "610": "jersey, T-shirt, tee shirt",
756
+ "611": "jigsaw puzzle",
757
+ "612": "jinrikisha, ricksha, rickshaw",
758
+ "613": "joystick",
759
+ "614": "kimono",
760
+ "615": "knee pad",
761
+ "616": "knot",
762
+ "617": "lab coat, laboratory coat",
763
+ "618": "ladle",
764
+ "619": "lampshade, lamp shade",
765
+ "620": "laptop, laptop computer",
766
+ "621": "lawn mower, mower",
767
+ "622": "lens cap, lens cover",
768
+ "623": "letter opener, paper knife, paperknife",
769
+ "624": "library",
770
+ "625": "lifeboat",
771
+ "626": "lighter, light, igniter, ignitor",
772
+ "627": "limousine, limo",
773
+ "628": "liner, ocean liner",
774
+ "629": "lipstick, lip rouge",
775
+ "630": "Loafer",
776
+ "631": "lotion",
777
+ "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
778
+ "633": "loupe, jeweler's loupe",
779
+ "634": "lumbermill, sawmill",
780
+ "635": "magnetic compass",
781
+ "636": "mailbag, postbag",
782
+ "637": "mailbox, letter box",
783
+ "638": "maillot",
784
+ "639": "maillot, tank suit",
785
+ "640": "manhole cover",
786
+ "641": "maraca",
787
+ "642": "marimba, xylophone",
788
+ "643": "mask",
789
+ "644": "matchstick",
790
+ "645": "maypole",
791
+ "646": "maze, labyrinth",
792
+ "647": "measuring cup",
793
+ "648": "medicine chest, medicine cabinet",
794
+ "649": "megalith, megalithic structure",
795
+ "650": "microphone, mike",
796
+ "651": "microwave, microwave oven",
797
+ "652": "military uniform",
798
+ "653": "milk can",
799
+ "654": "minibus",
800
+ "655": "miniskirt, mini",
801
+ "656": "minivan",
802
+ "657": "missile",
803
+ "658": "mitten",
804
+ "659": "mixing bowl",
805
+ "660": "mobile home, manufactured home",
806
+ "661": "Model T",
807
+ "662": "modem",
808
+ "663": "monastery",
809
+ "664": "monitor",
810
+ "665": "moped",
811
+ "666": "mortar",
812
+ "667": "mortarboard",
813
+ "668": "mosque",
814
+ "669": "mosquito net",
815
+ "670": "motor scooter, scooter",
816
+ "671": "mountain bike, all-terrain bike, off-roader",
817
+ "672": "mountain tent",
818
+ "673": "mouse, computer mouse",
819
+ "674": "mousetrap",
820
+ "675": "moving van",
821
+ "676": "muzzle",
822
+ "677": "nail",
823
+ "678": "neck brace",
824
+ "679": "necklace",
825
+ "680": "nipple",
826
+ "681": "notebook, notebook computer",
827
+ "682": "obelisk",
828
+ "683": "oboe, hautboy, hautbois",
829
+ "684": "ocarina, sweet potato",
830
+ "685": "odometer, hodometer, mileometer, milometer",
831
+ "686": "oil filter",
832
+ "687": "organ, pipe organ",
833
+ "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
834
+ "689": "overskirt",
835
+ "690": "oxcart",
836
+ "691": "oxygen mask",
837
+ "692": "packet",
838
+ "693": "paddle, boat paddle",
839
+ "694": "paddlewheel, paddle wheel",
840
+ "695": "padlock",
841
+ "696": "paintbrush",
842
+ "697": "pajama, pyjama, pj's, jammies",
843
+ "698": "palace",
844
+ "699": "panpipe, pandean pipe, syrinx",
845
+ "700": "paper towel",
846
+ "701": "parachute, chute",
847
+ "702": "parallel bars, bars",
848
+ "703": "park bench",
849
+ "704": "parking meter",
850
+ "705": "passenger car, coach, carriage",
851
+ "706": "patio, terrace",
852
+ "707": "pay-phone, pay-station",
853
+ "708": "pedestal, plinth, footstall",
854
+ "709": "pencil box, pencil case",
855
+ "710": "pencil sharpener",
856
+ "711": "perfume, essence",
857
+ "712": "Petri dish",
858
+ "713": "photocopier",
859
+ "714": "pick, plectrum, plectron",
860
+ "715": "pickelhaube",
861
+ "716": "picket fence, paling",
862
+ "717": "pickup, pickup truck",
863
+ "718": "pier",
864
+ "719": "piggy bank, penny bank",
865
+ "720": "pill bottle",
866
+ "721": "pillow",
867
+ "722": "ping-pong ball",
868
+ "723": "pinwheel",
869
+ "724": "pirate, pirate ship",
870
+ "725": "pitcher, ewer",
871
+ "726": "plane, carpenter's plane, woodworking plane",
872
+ "727": "planetarium",
873
+ "728": "plastic bag",
874
+ "729": "plate rack",
875
+ "730": "plow, plough",
876
+ "731": "plunger, plumber's helper",
877
+ "732": "Polaroid camera, Polaroid Land camera",
878
+ "733": "pole",
879
+ "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
880
+ "735": "poncho",
881
+ "736": "pool table, billiard table, snooker table",
882
+ "737": "pop bottle, soda bottle",
883
+ "738": "pot, flowerpot",
884
+ "739": "potter's wheel",
885
+ "740": "power drill",
886
+ "741": "prayer rug, prayer mat",
887
+ "742": "printer",
888
+ "743": "prison, prison house",
889
+ "744": "projectile, missile",
890
+ "745": "projector",
891
+ "746": "puck, hockey puck",
892
+ "747": "punching bag, punch bag, punching ball, punchball",
893
+ "748": "purse",
894
+ "749": "quill, quill pen",
895
+ "750": "quilt, comforter, comfort, puff",
896
+ "751": "racer, race car, racing car",
897
+ "752": "racket, racquet",
898
+ "753": "radiator",
899
+ "754": "radio, wireless",
900
+ "755": "radio telescope, radio reflector",
901
+ "756": "rain barrel",
902
+ "757": "recreational vehicle, RV, R.V.",
903
+ "758": "reel",
904
+ "759": "reflex camera",
905
+ "760": "refrigerator, icebox",
906
+ "761": "remote control, remote",
907
+ "762": "restaurant, eating house, eating place, eatery",
908
+ "763": "revolver, six-gun, six-shooter",
909
+ "764": "rifle",
910
+ "765": "rocking chair, rocker",
911
+ "766": "rotisserie",
912
+ "767": "rubber eraser, rubber, pencil eraser",
913
+ "768": "rugby ball",
914
+ "769": "rule, ruler",
915
+ "770": "running shoe",
916
+ "771": "safe",
917
+ "772": "safety pin",
918
+ "773": "saltshaker, salt shaker",
919
+ "774": "sandal",
920
+ "775": "sarong",
921
+ "776": "sax, saxophone",
922
+ "777": "scabbard",
923
+ "778": "scale, weighing machine",
924
+ "779": "school bus",
925
+ "780": "schooner",
926
+ "781": "scoreboard",
927
+ "782": "screen, CRT screen",
928
+ "783": "screw",
929
+ "784": "screwdriver",
930
+ "785": "seat belt, seatbelt",
931
+ "786": "sewing machine",
932
+ "787": "shield, buckler",
933
+ "788": "shoe shop, shoe-shop, shoe store",
934
+ "789": "shoji",
935
+ "790": "shopping basket",
936
+ "791": "shopping cart",
937
+ "792": "shovel",
938
+ "793": "shower cap",
939
+ "794": "shower curtain",
940
+ "795": "ski",
941
+ "796": "ski mask",
942
+ "797": "sleeping bag",
943
+ "798": "slide rule, slipstick",
944
+ "799": "sliding door",
945
+ "800": "slot, one-armed bandit",
946
+ "801": "snorkel",
947
+ "802": "snowmobile",
948
+ "803": "snowplow, snowplough",
949
+ "804": "soap dispenser",
950
+ "805": "soccer ball",
951
+ "806": "sock",
952
+ "807": "solar dish, solar collector, solar furnace",
953
+ "808": "sombrero",
954
+ "809": "soup bowl",
955
+ "810": "space bar",
956
+ "811": "space heater",
957
+ "812": "space shuttle",
958
+ "813": "spatula",
959
+ "814": "speedboat",
960
+ "815": "spider web, spider's web",
961
+ "816": "spindle",
962
+ "817": "sports car, sport car",
963
+ "818": "spotlight, spot",
964
+ "819": "stage",
965
+ "820": "steam locomotive",
966
+ "821": "steel arch bridge",
967
+ "822": "steel drum",
968
+ "823": "stethoscope",
969
+ "824": "stole",
970
+ "825": "stone wall",
971
+ "826": "stopwatch, stop watch",
972
+ "827": "stove",
973
+ "828": "strainer",
974
+ "829": "streetcar, tram, tramcar, trolley, trolley car",
975
+ "830": "stretcher",
976
+ "831": "studio couch, day bed",
977
+ "832": "stupa, tope",
978
+ "833": "submarine, pigboat, sub, U-boat",
979
+ "834": "suit, suit of clothes",
980
+ "835": "sundial",
981
+ "836": "sunglass",
982
+ "837": "sunglasses, dark glasses, shades",
983
+ "838": "sunscreen, sunblock, sun blocker",
984
+ "839": "suspension bridge",
985
+ "840": "swab, swob, mop",
986
+ "841": "sweatshirt",
987
+ "842": "swimming trunks, bathing trunks",
988
+ "843": "swing",
989
+ "844": "switch, electric switch, electrical switch",
990
+ "845": "syringe",
991
+ "846": "table lamp",
992
+ "847": "tank, army tank, armored combat vehicle, armoured combat vehicle",
993
+ "848": "tape player",
994
+ "849": "teapot",
995
+ "850": "teddy, teddy bear",
996
+ "851": "television, television system",
997
+ "852": "tennis ball",
998
+ "853": "thatch, thatched roof",
999
+ "854": "theater curtain, theatre curtain",
1000
+ "855": "thimble",
1001
+ "856": "thresher, thrasher, threshing machine",
1002
+ "857": "throne",
1003
+ "858": "tile roof",
1004
+ "859": "toaster",
1005
+ "860": "tobacco shop, tobacconist shop, tobacconist",
1006
+ "861": "toilet seat",
1007
+ "862": "torch",
1008
+ "863": "totem pole",
1009
+ "864": "tow truck, tow car, wrecker",
1010
+ "865": "toyshop",
1011
+ "866": "tractor",
1012
+ "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
1013
+ "868": "tray",
1014
+ "869": "trench coat",
1015
+ "870": "tricycle, trike, velocipede",
1016
+ "871": "trimaran",
1017
+ "872": "tripod",
1018
+ "873": "triumphal arch",
1019
+ "874": "trolleybus, trolley coach, trackless trolley",
1020
+ "875": "trombone",
1021
+ "876": "tub, vat",
1022
+ "877": "turnstile",
1023
+ "878": "typewriter keyboard",
1024
+ "879": "umbrella",
1025
+ "880": "unicycle, monocycle",
1026
+ "881": "upright, upright piano",
1027
+ "882": "vacuum, vacuum cleaner",
1028
+ "883": "vase",
1029
+ "884": "vault",
1030
+ "885": "velvet",
1031
+ "886": "vending machine",
1032
+ "887": "vestment",
1033
+ "888": "viaduct",
1034
+ "889": "violin, fiddle",
1035
+ "890": "volleyball",
1036
+ "891": "waffle iron",
1037
+ "892": "wall clock",
1038
+ "893": "wallet, billfold, notecase, pocketbook",
1039
+ "894": "wardrobe, closet, press",
1040
+ "895": "warplane, military plane",
1041
+ "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
1042
+ "897": "washer, automatic washer, washing machine",
1043
+ "898": "water bottle",
1044
+ "899": "water jug",
1045
+ "900": "water tower",
1046
+ "901": "whiskey jug",
1047
+ "902": "whistle",
1048
+ "903": "wig",
1049
+ "904": "window screen",
1050
+ "905": "window shade",
1051
+ "906": "Windsor tie",
1052
+ "907": "wine bottle",
1053
+ "908": "wing",
1054
+ "909": "wok",
1055
+ "910": "wooden spoon",
1056
+ "911": "wool, woolen, woollen",
1057
+ "912": "worm fence, snake fence, snake-rail fence, Virginia fence",
1058
+ "913": "wreck",
1059
+ "914": "yawl",
1060
+ "915": "yurt",
1061
+ "916": "web site, website, internet site, site",
1062
+ "917": "comic book",
1063
+ "918": "crossword puzzle, crossword",
1064
+ "919": "street sign",
1065
+ "920": "traffic light, traffic signal, stoplight",
1066
+ "921": "book jacket, dust cover, dust jacket, dust wrapper",
1067
+ "922": "menu",
1068
+ "923": "plate",
1069
+ "924": "guacamole",
1070
+ "925": "consomme",
1071
+ "926": "hot pot, hotpot",
1072
+ "927": "trifle",
1073
+ "928": "ice cream, icecream",
1074
+ "929": "ice lolly, lolly, lollipop, popsicle",
1075
+ "930": "French loaf",
1076
+ "931": "bagel, beigel",
1077
+ "932": "pretzel",
1078
+ "933": "cheeseburger",
1079
+ "934": "hotdog, hot dog, red hot",
1080
+ "935": "mashed potato",
1081
+ "936": "head cabbage",
1082
+ "937": "broccoli",
1083
+ "938": "cauliflower",
1084
+ "939": "zucchini, courgette",
1085
+ "940": "spaghetti squash",
1086
+ "941": "acorn squash",
1087
+ "942": "butternut squash",
1088
+ "943": "cucumber, cuke",
1089
+ "944": "artichoke, globe artichoke",
1090
+ "945": "bell pepper",
1091
+ "946": "cardoon",
1092
+ "947": "mushroom",
1093
+ "948": "Granny Smith",
1094
+ "949": "strawberry",
1095
+ "950": "orange",
1096
+ "951": "lemon",
1097
+ "952": "fig",
1098
+ "953": "pineapple, ananas",
1099
+ "954": "banana",
1100
+ "955": "jackfruit, jak, jack",
1101
+ "956": "custard apple",
1102
+ "957": "pomegranate",
1103
+ "958": "hay",
1104
+ "959": "carbonara",
1105
+ "960": "chocolate sauce, chocolate syrup",
1106
+ "961": "dough",
1107
+ "962": "meat loaf, meatloaf",
1108
+ "963": "pizza, pizza pie",
1109
+ "964": "potpie",
1110
+ "965": "burrito",
1111
+ "966": "red wine",
1112
+ "967": "espresso",
1113
+ "968": "cup",
1114
+ "969": "eggnog",
1115
+ "970": "alp",
1116
+ "971": "bubble",
1117
+ "972": "cliff, drop, drop-off",
1118
+ "973": "coral reef",
1119
+ "974": "geyser",
1120
+ "975": "lakeside, lakeshore",
1121
+ "976": "promontory, headland, head, foreland",
1122
+ "977": "sandbar, sand bar",
1123
+ "978": "seashore, coast, seacoast, sea-coast",
1124
+ "979": "valley, vale",
1125
+ "980": "volcano",
1126
+ "981": "ballplayer, baseball player",
1127
+ "982": "groom, bridegroom",
1128
+ "983": "scuba diver",
1129
+ "984": "rapeseed",
1130
+ "985": "daisy",
1131
+ "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
1132
+ "987": "corn",
1133
+ "988": "acorn",
1134
+ "989": "hip, rose hip, rosehip",
1135
+ "990": "buckeye, horse chestnut, conker",
1136
+ "991": "coral fungus",
1137
+ "992": "agaric",
1138
+ "993": "gyromitra",
1139
+ "994": "stinkhorn, carrion fungus",
1140
+ "995": "earthstar",
1141
+ "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1142
+ "997": "bolete",
1143
+ "998": "ear, spike, capitulum",
1144
+ "999": "toilet tissue, toilet paper, bathroom tissue"
1145
+ },
1146
+ "image_size": 384,
1147
+ "initializer_range": 0.02,
1148
+ "is_decoder": false,
1149
+ "is_encoder_decoder": false,
1150
+ "label2id": {
1151
+ "Afghan hound, Afghan": 160,
1152
+ "African chameleon, Chamaeleo chamaeleon": 47,
1153
+ "African crocodile, Nile crocodile, Crocodylus niloticus": 49,
1154
+ "African elephant, Loxodonta africana": 386,
1155
+ "African grey, African gray, Psittacus erithacus": 87,
1156
+ "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus": 275,
1157
+ "Airedale, Airedale terrier": 191,
1158
+ "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier": 180,
1159
+ "American alligator, Alligator mississipiensis": 50,
1160
+ "American black bear, black bear, Ursus americanus, Euarctos americanus": 295,
1161
+ "American chameleon, anole, Anolis carolinensis": 40,
1162
+ "American coot, marsh hen, mud hen, water hen, Fulica americana": 137,
1163
+ "American egret, great white heron, Egretta albus": 132,
1164
+ "American lobster, Northern lobster, Maine lobster, Homarus americanus": 122,
1165
+ "Angora, Angora rabbit": 332,
1166
+ "Appenzeller": 240,
1167
+ "Arabian camel, dromedary, Camelus dromedarius": 354,
1168
+ "Arctic fox, white fox, Alopex lagopus": 279,
1169
+ "Australian terrier": 193,
1170
+ "Band Aid": 419,
1171
+ "Bedlington terrier": 181,
1172
+ "Bernese mountain dog": 239,
1173
+ "Blenheim spaniel": 156,
1174
+ "Border collie": 232,
1175
+ "Border terrier": 182,
1176
+ "Boston bull, Boston terrier": 195,
1177
+ "Bouvier des Flandres, Bouviers des Flandres": 233,
1178
+ "Brabancon griffon": 262,
1179
+ "Brittany spaniel": 215,
1180
+ "CD player": 485,
1181
+ "Cardigan, Cardigan Welsh corgi": 264,
1182
+ "Chesapeake Bay retriever": 209,
1183
+ "Chihuahua": 151,
1184
+ "Christmas stocking": 496,
1185
+ "Crock Pot": 521,
1186
+ "Dandie Dinmont, Dandie Dinmont terrier": 194,
1187
+ "Doberman, Doberman pinscher": 236,
1188
+ "Dungeness crab, Cancer magister": 118,
1189
+ "Dutch oven": 544,
1190
+ "Egyptian cat": 285,
1191
+ "English foxhound": 167,
1192
+ "English setter": 212,
1193
+ "English springer, English springer spaniel": 217,
1194
+ "EntleBucher": 241,
1195
+ "Eskimo dog, husky": 248,
1196
+ "European fire salamander, Salamandra salamandra": 25,
1197
+ "European gallinule, Porphyrio porphyrio": 136,
1198
+ "French bulldog": 245,
1199
+ "French horn, horn": 566,
1200
+ "French loaf": 930,
1201
+ "German shepherd, German shepherd dog, German police dog, alsatian": 235,
1202
+ "German short-haired pointer": 210,
1203
+ "Gila monster, Heloderma suspectum": 45,
1204
+ "Gordon setter": 214,
1205
+ "Granny Smith": 948,
1206
+ "Great Dane": 246,
1207
+ "Great Pyrenees": 257,
1208
+ "Greater Swiss Mountain dog": 238,
1209
+ "Ibizan hound, Ibizan Podenco": 173,
1210
+ "Indian cobra, Naja naja": 63,
1211
+ "Indian elephant, Elephas maximus": 385,
1212
+ "Irish setter, red setter": 213,
1213
+ "Irish terrier": 184,
1214
+ "Irish water spaniel": 221,
1215
+ "Irish wolfhound": 170,
1216
+ "Italian greyhound": 171,
1217
+ "Japanese spaniel": 152,
1218
+ "Kerry blue terrier": 183,
1219
+ "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis": 48,
1220
+ "Labrador retriever": 208,
1221
+ "Lakeland terrier": 189,
1222
+ "Leonberg": 255,
1223
+ "Lhasa, Lhasa apso": 204,
1224
+ "Loafer": 630,
1225
+ "Madagascar cat, ring-tailed lemur, Lemur catta": 383,
1226
+ "Maltese dog, Maltese terrier, Maltese": 153,
1227
+ "Mexican hairless": 268,
1228
+ "Model T": 661,
1229
+ "Newfoundland, Newfoundland dog": 256,
1230
+ "Norfolk terrier": 185,
1231
+ "Norwegian elkhound, elkhound": 174,
1232
+ "Norwich terrier": 186,
1233
+ "Old English sheepdog, bobtail": 229,
1234
+ "Pekinese, Pekingese, Peke": 154,
1235
+ "Pembroke, Pembroke Welsh corgi": 263,
1236
+ "Persian cat": 283,
1237
+ "Petri dish": 712,
1238
+ "Polaroid camera, Polaroid Land camera": 732,
1239
+ "Pomeranian": 259,
1240
+ "Rhodesian ridgeback": 159,
1241
+ "Rottweiler": 234,
1242
+ "Saint Bernard, St Bernard": 247,
1243
+ "Saluki, gazelle hound": 176,
1244
+ "Samoyed, Samoyede": 258,
1245
+ "Scotch terrier, Scottish terrier, Scottie": 199,
1246
+ "Scottish deerhound, deerhound": 177,
1247
+ "Sealyham terrier, Sealyham": 190,
1248
+ "Shetland sheepdog, Shetland sheep dog, Shetland": 230,
1249
+ "Shih-Tzu": 155,
1250
+ "Siamese cat, Siamese": 284,
1251
+ "Siberian husky": 250,
1252
+ "Staffordshire bullterrier, Staffordshire bull terrier": 179,
1253
+ "Sussex spaniel": 220,
1254
+ "Tibetan mastiff": 244,
1255
+ "Tibetan terrier, chrysanthemum dog": 200,
1256
+ "Walker hound, Walker foxhound": 166,
1257
+ "Weimaraner": 178,
1258
+ "Welsh springer spaniel": 218,
1259
+ "West Highland white terrier": 203,
1260
+ "Windsor tie": 906,
1261
+ "Yorkshire terrier": 187,
1262
+ "abacus": 398,
1263
+ "abaya": 399,
1264
+ "academic gown, academic robe, judge's robe": 400,
1265
+ "accordion, piano accordion, squeeze box": 401,
1266
+ "acorn": 988,
1267
+ "acorn squash": 941,
1268
+ "acoustic guitar": 402,
1269
+ "admiral": 321,
1270
+ "affenpinscher, monkey pinscher, monkey dog": 252,
1271
+ "agama": 42,
1272
+ "agaric": 992,
1273
+ "aircraft carrier, carrier, flattop, attack aircraft carrier": 403,
1274
+ "airliner": 404,
1275
+ "airship, dirigible": 405,
1276
+ "albatross, mollymawk": 146,
1277
+ "alligator lizard": 44,
1278
+ "alp": 970,
1279
+ "altar": 406,
1280
+ "ambulance": 407,
1281
+ "amphibian, amphibious vehicle": 408,
1282
+ "analog clock": 409,
1283
+ "anemone fish": 393,
1284
+ "ant, emmet, pismire": 310,
1285
+ "apiary, bee house": 410,
1286
+ "apron": 411,
1287
+ "armadillo": 363,
1288
+ "artichoke, globe artichoke": 944,
1289
+ "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin": 412,
1290
+ "assault rifle, assault gun": 413,
1291
+ "axolotl, mud puppy, Ambystoma mexicanum": 29,
1292
+ "baboon": 372,
1293
+ "backpack, back pack, knapsack, packsack, rucksack, haversack": 414,
1294
+ "badger": 362,
1295
+ "bagel, beigel": 931,
1296
+ "bakery, bakeshop, bakehouse": 415,
1297
+ "balance beam, beam": 416,
1298
+ "bald eagle, American eagle, Haliaeetus leucocephalus": 22,
1299
+ "balloon": 417,
1300
+ "ballplayer, baseball player": 981,
1301
+ "ballpoint, ballpoint pen, ballpen, Biro": 418,
1302
+ "banana": 954,
1303
+ "banded gecko": 38,
1304
+ "banjo": 420,
1305
+ "bannister, banister, balustrade, balusters, handrail": 421,
1306
+ "barbell": 422,
1307
+ "barber chair": 423,
1308
+ "barbershop": 424,
1309
+ "barn": 425,
1310
+ "barn spider, Araneus cavaticus": 73,
1311
+ "barometer": 426,
1312
+ "barracouta, snoek": 389,
1313
+ "barrel, cask": 427,
1314
+ "barrow, garden cart, lawn cart, wheelbarrow": 428,
1315
+ "baseball": 429,
1316
+ "basenji": 253,
1317
+ "basketball": 430,
1318
+ "basset, basset hound": 161,
1319
+ "bassinet": 431,
1320
+ "bassoon": 432,
1321
+ "bath towel": 434,
1322
+ "bathing cap, swimming cap": 433,
1323
+ "bathtub, bathing tub, bath, tub": 435,
1324
+ "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon": 436,
1325
+ "beacon, lighthouse, beacon light, pharos": 437,
1326
+ "beagle": 162,
1327
+ "beaker": 438,
1328
+ "bearskin, busby, shako": 439,
1329
+ "beaver": 337,
1330
+ "bee": 309,
1331
+ "bee eater": 92,
1332
+ "beer bottle": 440,
1333
+ "beer glass": 441,
1334
+ "bell cote, bell cot": 442,
1335
+ "bell pepper": 945,
1336
+ "bib": 443,
1337
+ "bicycle-built-for-two, tandem bicycle, tandem": 444,
1338
+ "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis": 349,
1339
+ "bikini, two-piece": 445,
1340
+ "binder, ring-binder": 446,
1341
+ "binoculars, field glasses, opera glasses": 447,
1342
+ "birdhouse": 448,
1343
+ "bison": 347,
1344
+ "bittern": 133,
1345
+ "black and gold garden spider, Argiope aurantia": 72,
1346
+ "black grouse": 80,
1347
+ "black stork, Ciconia nigra": 128,
1348
+ "black swan, Cygnus atratus": 100,
1349
+ "black widow, Latrodectus mactans": 75,
1350
+ "black-and-tan coonhound": 165,
1351
+ "black-footed ferret, ferret, Mustela nigripes": 359,
1352
+ "bloodhound, sleuthhound": 163,
1353
+ "bluetick": 164,
1354
+ "boa constrictor, Constrictor constrictor": 61,
1355
+ "boathouse": 449,
1356
+ "bobsled, bobsleigh, bob": 450,
1357
+ "bolete": 997,
1358
+ "bolo tie, bolo, bola tie, bola": 451,
1359
+ "bonnet, poke bonnet": 452,
1360
+ "book jacket, dust cover, dust jacket, dust wrapper": 921,
1361
+ "bookcase": 453,
1362
+ "bookshop, bookstore, bookstall": 454,
1363
+ "borzoi, Russian wolfhound": 169,
1364
+ "bottlecap": 455,
1365
+ "bow": 456,
1366
+ "bow tie, bow-tie, bowtie": 457,
1367
+ "box turtle, box tortoise": 37,
1368
+ "boxer": 242,
1369
+ "brain coral": 109,
1370
+ "brambling, Fringilla montifringilla": 10,
1371
+ "brass, memorial tablet, plaque": 458,
1372
+ "brassiere, bra, bandeau": 459,
1373
+ "breakwater, groin, groyne, mole, bulwark, seawall, jetty": 460,
1374
+ "breastplate, aegis, egis": 461,
1375
+ "briard": 226,
1376
+ "broccoli": 937,
1377
+ "broom": 462,
1378
+ "brown bear, bruin, Ursus arctos": 294,
1379
+ "bubble": 971,
1380
+ "bucket, pail": 463,
1381
+ "buckeye, horse chestnut, conker": 990,
1382
+ "buckle": 464,
1383
+ "bulbul": 16,
1384
+ "bull mastiff": 243,
1385
+ "bullet train, bullet": 466,
1386
+ "bulletproof vest": 465,
1387
+ "bullfrog, Rana catesbeiana": 30,
1388
+ "burrito": 965,
1389
+ "bustard": 138,
1390
+ "butcher shop, meat market": 467,
1391
+ "butternut squash": 942,
1392
+ "cab, hack, taxi, taxicab": 468,
1393
+ "cabbage butterfly": 324,
1394
+ "cairn, cairn terrier": 192,
1395
+ "caldron, cauldron": 469,
1396
+ "can opener, tin opener": 473,
1397
+ "candle, taper, wax light": 470,
1398
+ "cannon": 471,
1399
+ "canoe": 472,
1400
+ "capuchin, ringtail, Cebus capucinus": 378,
1401
+ "car mirror": 475,
1402
+ "car wheel": 479,
1403
+ "carbonara": 959,
1404
+ "cardigan": 474,
1405
+ "cardoon": 946,
1406
+ "carousel, carrousel, merry-go-round, roundabout, whirligig": 476,
1407
+ "carpenter's kit, tool kit": 477,
1408
+ "carton": 478,
1409
+ "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM": 480,
1410
+ "cassette": 481,
1411
+ "cassette player": 482,
1412
+ "castle": 483,
1413
+ "catamaran": 484,
1414
+ "cauliflower": 938,
1415
+ "cello, violoncello": 486,
1416
+ "cellular telephone, cellular phone, cellphone, cell, mobile phone": 487,
1417
+ "centipede": 79,
1418
+ "chain": 488,
1419
+ "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour": 490,
1420
+ "chain saw, chainsaw": 491,
1421
+ "chainlink fence": 489,
1422
+ "chambered nautilus, pearly nautilus, nautilus": 117,
1423
+ "cheeseburger": 933,
1424
+ "cheetah, chetah, Acinonyx jubatus": 293,
1425
+ "chest": 492,
1426
+ "chickadee": 19,
1427
+ "chiffonier, commode": 493,
1428
+ "chime, bell, gong": 494,
1429
+ "chimpanzee, chimp, Pan troglodytes": 367,
1430
+ "china cabinet, china closet": 495,
1431
+ "chiton, coat-of-mail shell, sea cradle, polyplacophore": 116,
1432
+ "chocolate sauce, chocolate syrup": 960,
1433
+ "chow, chow chow": 260,
1434
+ "church, church building": 497,
1435
+ "cicada, cicala": 316,
1436
+ "cinema, movie theater, movie theatre, movie house, picture palace": 498,
1437
+ "cleaver, meat cleaver, chopper": 499,
1438
+ "cliff dwelling": 500,
1439
+ "cliff, drop, drop-off": 972,
1440
+ "cloak": 501,
1441
+ "clog, geta, patten, sabot": 502,
1442
+ "clumber, clumber spaniel": 216,
1443
+ "cock": 7,
1444
+ "cocker spaniel, English cocker spaniel, cocker": 219,
1445
+ "cockroach, roach": 314,
1446
+ "cocktail shaker": 503,
1447
+ "coffee mug": 504,
1448
+ "coffeepot": 505,
1449
+ "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch": 391,
1450
+ "coil, spiral, volute, whorl, helix": 506,
1451
+ "collie": 231,
1452
+ "colobus, colobus monkey": 375,
1453
+ "combination lock": 507,
1454
+ "comic book": 917,
1455
+ "common iguana, iguana, Iguana iguana": 39,
1456
+ "common newt, Triturus vulgaris": 26,
1457
+ "computer keyboard, keypad": 508,
1458
+ "conch": 112,
1459
+ "confectionery, confectionary, candy store": 509,
1460
+ "consomme": 925,
1461
+ "container ship, containership, container vessel": 510,
1462
+ "convertible": 511,
1463
+ "coral fungus": 991,
1464
+ "coral reef": 973,
1465
+ "corkscrew, bottle screw": 512,
1466
+ "corn": 987,
1467
+ "cornet, horn, trumpet, trump": 513,
1468
+ "coucal": 91,
1469
+ "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor": 286,
1470
+ "cowboy boot": 514,
1471
+ "cowboy hat, ten-gallon hat": 515,
1472
+ "coyote, prairie wolf, brush wolf, Canis latrans": 272,
1473
+ "cradle": 516,
1474
+ "crane": 517,
1475
+ "crash helmet": 518,
1476
+ "crate": 519,
1477
+ "crayfish, crawfish, crawdad, crawdaddy": 124,
1478
+ "crib, cot": 520,
1479
+ "cricket": 312,
1480
+ "croquet ball": 522,
1481
+ "crossword puzzle, crossword": 918,
1482
+ "crutch": 523,
1483
+ "cucumber, cuke": 943,
1484
+ "cuirass": 524,
1485
+ "cup": 968,
1486
+ "curly-coated retriever": 206,
1487
+ "custard apple": 956,
1488
+ "daisy": 985,
1489
+ "dalmatian, coach dog, carriage dog": 251,
1490
+ "dam, dike, dyke": 525,
1491
+ "damselfly": 320,
1492
+ "desk": 526,
1493
+ "desktop computer": 527,
1494
+ "dhole, Cuon alpinus": 274,
1495
+ "dial telephone, dial phone": 528,
1496
+ "diamondback, diamondback rattlesnake, Crotalus adamanteus": 67,
1497
+ "diaper, nappy, napkin": 529,
1498
+ "digital clock": 530,
1499
+ "digital watch": 531,
1500
+ "dingo, warrigal, warragal, Canis dingo": 273,
1501
+ "dining table, board": 532,
1502
+ "dishrag, dishcloth": 533,
1503
+ "dishwasher, dish washer, dishwashing machine": 534,
1504
+ "disk brake, disc brake": 535,
1505
+ "dock, dockage, docking facility": 536,
1506
+ "dogsled, dog sled, dog sleigh": 537,
1507
+ "dome": 538,
1508
+ "doormat, welcome mat": 539,
1509
+ "dough": 961,
1510
+ "dowitcher": 142,
1511
+ "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk": 319,
1512
+ "drake": 97,
1513
+ "drilling platform, offshore rig": 540,
1514
+ "drum, membranophone, tympan": 541,
1515
+ "drumstick": 542,
1516
+ "dugong, Dugong dugon": 149,
1517
+ "dumbbell": 543,
1518
+ "dung beetle": 305,
1519
+ "ear, spike, capitulum": 998,
1520
+ "earthstar": 995,
1521
+ "echidna, spiny anteater, anteater": 102,
1522
+ "eel": 390,
1523
+ "eft": 27,
1524
+ "eggnog": 969,
1525
+ "electric fan, blower": 545,
1526
+ "electric guitar": 546,
1527
+ "electric locomotive": 547,
1528
+ "electric ray, crampfish, numbfish, torpedo": 5,
1529
+ "entertainment center": 548,
1530
+ "envelope": 549,
1531
+ "espresso": 967,
1532
+ "espresso maker": 550,
1533
+ "face powder": 551,
1534
+ "feather boa, boa": 552,
1535
+ "fiddler crab": 120,
1536
+ "fig": 952,
1537
+ "file, file cabinet, filing cabinet": 553,
1538
+ "fire engine, fire truck": 555,
1539
+ "fire screen, fireguard": 556,
1540
+ "fireboat": 554,
1541
+ "flagpole, flagstaff": 557,
1542
+ "flamingo": 130,
1543
+ "flat-coated retriever": 205,
1544
+ "flatworm, platyhelminth": 110,
1545
+ "flute, transverse flute": 558,
1546
+ "fly": 308,
1547
+ "folding chair": 559,
1548
+ "football helmet": 560,
1549
+ "forklift": 561,
1550
+ "fountain": 562,
1551
+ "fountain pen": 563,
1552
+ "four-poster": 564,
1553
+ "fox squirrel, eastern fox squirrel, Sciurus niger": 335,
1554
+ "freight car": 565,
1555
+ "frilled lizard, Chlamydosaurus kingi": 43,
1556
+ "frying pan, frypan, skillet": 567,
1557
+ "fur coat": 568,
1558
+ "gar, garfish, garpike, billfish, Lepisosteus osseus": 395,
1559
+ "garbage truck, dustcart": 569,
1560
+ "garden spider, Aranea diademata": 74,
1561
+ "garter snake, grass snake": 57,
1562
+ "gas pump, gasoline pump, petrol pump, island dispenser": 571,
1563
+ "gasmask, respirator, gas helmet": 570,
1564
+ "gazelle": 353,
1565
+ "geyser": 974,
1566
+ "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca": 388,
1567
+ "giant schnauzer": 197,
1568
+ "gibbon, Hylobates lar": 368,
1569
+ "go-kart": 573,
1570
+ "goblet": 572,
1571
+ "golden retriever": 207,
1572
+ "goldfinch, Carduelis carduelis": 11,
1573
+ "goldfish, Carassius auratus": 1,
1574
+ "golf ball": 574,
1575
+ "golfcart, golf cart": 575,
1576
+ "gondola": 576,
1577
+ "gong, tam-tam": 577,
1578
+ "goose": 99,
1579
+ "gorilla, Gorilla gorilla": 366,
1580
+ "gown": 578,
1581
+ "grand piano, grand": 579,
1582
+ "grasshopper, hopper": 311,
1583
+ "great grey owl, great gray owl, Strix nebulosa": 24,
1584
+ "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias": 2,
1585
+ "green lizard, Lacerta viridis": 46,
1586
+ "green mamba": 64,
1587
+ "green snake, grass snake": 55,
1588
+ "greenhouse, nursery, glasshouse": 580,
1589
+ "grey fox, gray fox, Urocyon cinereoargenteus": 280,
1590
+ "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus": 147,
1591
+ "grille, radiator grille": 581,
1592
+ "grocery store, grocery, food market, market": 582,
1593
+ "groenendael": 224,
1594
+ "groom, bridegroom": 982,
1595
+ "ground beetle, carabid beetle": 302,
1596
+ "guacamole": 924,
1597
+ "guenon, guenon monkey": 370,
1598
+ "guillotine": 583,
1599
+ "guinea pig, Cavia cobaya": 338,
1600
+ "gyromitra": 993,
1601
+ "hair slide": 584,
1602
+ "hair spray": 585,
1603
+ "half track": 586,
1604
+ "hammer": 587,
1605
+ "hammerhead, hammerhead shark": 4,
1606
+ "hamper": 588,
1607
+ "hamster": 333,
1608
+ "hand blower, blow dryer, blow drier, hair dryer, hair drier": 589,
1609
+ "hand-held computer, hand-held microcomputer": 590,
1610
+ "handkerchief, hankie, hanky, hankey": 591,
1611
+ "hard disc, hard disk, fixed disk": 592,
1612
+ "hare": 331,
1613
+ "harmonica, mouth organ, harp, mouth harp": 593,
1614
+ "harp": 594,
1615
+ "hartebeest": 351,
1616
+ "harvester, reaper": 595,
1617
+ "harvestman, daddy longlegs, Phalangium opilio": 70,
1618
+ "hatchet": 596,
1619
+ "hay": 958,
1620
+ "head cabbage": 936,
1621
+ "hen": 8,
1622
+ "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa": 996,
1623
+ "hermit crab": 125,
1624
+ "hip, rose hip, rosehip": 989,
1625
+ "hippopotamus, hippo, river horse, Hippopotamus amphibius": 344,
1626
+ "hog, pig, grunter, squealer, Sus scrofa": 341,
1627
+ "hognose snake, puff adder, sand viper": 54,
1628
+ "holster": 597,
1629
+ "home theater, home theatre": 598,
1630
+ "honeycomb": 599,
1631
+ "hook, claw": 600,
1632
+ "hoopskirt, crinoline": 601,
1633
+ "horizontal bar, high bar": 602,
1634
+ "hornbill": 93,
1635
+ "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus": 66,
1636
+ "horse cart, horse-cart": 603,
1637
+ "hot pot, hotpot": 926,
1638
+ "hotdog, hot dog, red hot": 934,
1639
+ "hourglass": 604,
1640
+ "house finch, linnet, Carpodacus mexicanus": 12,
1641
+ "howler monkey, howler": 379,
1642
+ "hummingbird": 94,
1643
+ "hyena, hyaena": 276,
1644
+ "iPod": 605,
1645
+ "ibex, Capra ibex": 350,
1646
+ "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus": 296,
1647
+ "ice cream, icecream": 928,
1648
+ "ice lolly, lolly, lollipop, popsicle": 929,
1649
+ "impala, Aepyceros melampus": 352,
1650
+ "indigo bunting, indigo finch, indigo bird, Passerina cyanea": 14,
1651
+ "indri, indris, Indri indri, Indri brevicaudatus": 384,
1652
+ "iron, smoothing iron": 606,
1653
+ "isopod": 126,
1654
+ "jacamar": 95,
1655
+ "jack-o'-lantern": 607,
1656
+ "jackfruit, jak, jack": 955,
1657
+ "jaguar, panther, Panthera onca, Felis onca": 290,
1658
+ "jay": 17,
1659
+ "jean, blue jean, denim": 608,
1660
+ "jeep, landrover": 609,
1661
+ "jellyfish": 107,
1662
+ "jersey, T-shirt, tee shirt": 610,
1663
+ "jigsaw puzzle": 611,
1664
+ "jinrikisha, ricksha, rickshaw": 612,
1665
+ "joystick": 613,
1666
+ "junco, snowbird": 13,
1667
+ "keeshond": 261,
1668
+ "kelpie": 227,
1669
+ "killer whale, killer, orca, grampus, sea wolf, Orcinus orca": 148,
1670
+ "kimono": 614,
1671
+ "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica": 121,
1672
+ "king penguin, Aptenodytes patagonica": 145,
1673
+ "king snake, kingsnake": 56,
1674
+ "kit fox, Vulpes macrotis": 278,
1675
+ "kite": 21,
1676
+ "knee pad": 615,
1677
+ "knot": 616,
1678
+ "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus": 105,
1679
+ "komondor": 228,
1680
+ "kuvasz": 222,
1681
+ "lab coat, laboratory coat": 617,
1682
+ "lacewing, lacewing fly": 318,
1683
+ "ladle": 618,
1684
+ "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle": 301,
1685
+ "lakeside, lakeshore": 975,
1686
+ "lampshade, lamp shade": 619,
1687
+ "langur": 374,
1688
+ "laptop, laptop computer": 620,
1689
+ "lawn mower, mower": 621,
1690
+ "leaf beetle, chrysomelid": 304,
1691
+ "leafhopper": 317,
1692
+ "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea": 34,
1693
+ "lemon": 951,
1694
+ "lens cap, lens cover": 622,
1695
+ "leopard, Panthera pardus": 288,
1696
+ "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens": 387,
1697
+ "letter opener, paper knife, paperknife": 623,
1698
+ "library": 624,
1699
+ "lifeboat": 625,
1700
+ "lighter, light, igniter, ignitor": 626,
1701
+ "limousine, limo": 627,
1702
+ "limpkin, Aramus pictus": 135,
1703
+ "liner, ocean liner": 628,
1704
+ "lion, king of beasts, Panthera leo": 291,
1705
+ "lionfish": 396,
1706
+ "lipstick, lip rouge": 629,
1707
+ "little blue heron, Egretta caerulea": 131,
1708
+ "llama": 355,
1709
+ "loggerhead, loggerhead turtle, Caretta caretta": 33,
1710
+ "long-horned beetle, longicorn, longicorn beetle": 303,
1711
+ "lorikeet": 90,
1712
+ "lotion": 631,
1713
+ "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system": 632,
1714
+ "loupe, jeweler's loupe": 633,
1715
+ "lumbermill, sawmill": 634,
1716
+ "lycaenid, lycaenid butterfly": 326,
1717
+ "lynx, catamount": 287,
1718
+ "macaque": 373,
1719
+ "macaw": 88,
1720
+ "magnetic compass": 635,
1721
+ "magpie": 18,
1722
+ "mailbag, postbag": 636,
1723
+ "mailbox, letter box": 637,
1724
+ "maillot": 638,
1725
+ "maillot, tank suit": 639,
1726
+ "malamute, malemute, Alaskan malamute": 249,
1727
+ "malinois": 225,
1728
+ "manhole cover": 640,
1729
+ "mantis, mantid": 315,
1730
+ "maraca": 641,
1731
+ "marimba, xylophone": 642,
1732
+ "marmoset": 377,
1733
+ "marmot": 336,
1734
+ "mashed potato": 935,
1735
+ "mask": 643,
1736
+ "matchstick": 644,
1737
+ "maypole": 645,
1738
+ "maze, labyrinth": 646,
1739
+ "measuring cup": 647,
1740
+ "meat loaf, meatloaf": 962,
1741
+ "medicine chest, medicine cabinet": 648,
1742
+ "meerkat, mierkat": 299,
1743
+ "megalith, megalithic structure": 649,
1744
+ "menu": 922,
1745
+ "microphone, mike": 650,
1746
+ "microwave, microwave oven": 651,
1747
+ "military uniform": 652,
1748
+ "milk can": 653,
1749
+ "miniature pinscher": 237,
1750
+ "miniature poodle": 266,
1751
+ "miniature schnauzer": 196,
1752
+ "minibus": 654,
1753
+ "miniskirt, mini": 655,
1754
+ "minivan": 656,
1755
+ "mink": 357,
1756
+ "missile": 657,
1757
+ "mitten": 658,
1758
+ "mixing bowl": 659,
1759
+ "mobile home, manufactured home": 660,
1760
+ "modem": 662,
1761
+ "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus": 323,
1762
+ "monastery": 663,
1763
+ "mongoose": 298,
1764
+ "monitor": 664,
1765
+ "moped": 665,
1766
+ "mortar": 666,
1767
+ "mortarboard": 667,
1768
+ "mosque": 668,
1769
+ "mosquito net": 669,
1770
+ "motor scooter, scooter": 670,
1771
+ "mountain bike, all-terrain bike, off-roader": 671,
1772
+ "mountain tent": 672,
1773
+ "mouse, computer mouse": 673,
1774
+ "mousetrap": 674,
1775
+ "moving van": 675,
1776
+ "mud turtle": 35,
1777
+ "mushroom": 947,
1778
+ "muzzle": 676,
1779
+ "nail": 677,
1780
+ "neck brace": 678,
1781
+ "necklace": 679,
1782
+ "nematode, nematode worm, roundworm": 111,
1783
+ "night snake, Hypsiglena torquata": 60,
1784
+ "nipple": 680,
1785
+ "notebook, notebook computer": 681,
1786
+ "obelisk": 682,
1787
+ "oboe, hautboy, hautbois": 683,
1788
+ "ocarina, sweet potato": 684,
1789
+ "odometer, hodometer, mileometer, milometer": 685,
1790
+ "oil filter": 686,
1791
+ "orange": 950,
1792
+ "orangutan, orang, orangutang, Pongo pygmaeus": 365,
1793
+ "organ, pipe organ": 687,
1794
+ "oscilloscope, scope, cathode-ray oscilloscope, CRO": 688,
1795
+ "ostrich, Struthio camelus": 9,
1796
+ "otter": 360,
1797
+ "otterhound, otter hound": 175,
1798
+ "overskirt": 689,
1799
+ "ox": 345,
1800
+ "oxcart": 690,
1801
+ "oxygen mask": 691,
1802
+ "oystercatcher, oyster catcher": 143,
1803
+ "packet": 692,
1804
+ "paddle, boat paddle": 693,
1805
+ "paddlewheel, paddle wheel": 694,
1806
+ "padlock": 695,
1807
+ "paintbrush": 696,
1808
+ "pajama, pyjama, pj's, jammies": 697,
1809
+ "palace": 698,
1810
+ "panpipe, pandean pipe, syrinx": 699,
1811
+ "paper towel": 700,
1812
+ "papillon": 157,
1813
+ "parachute, chute": 701,
1814
+ "parallel bars, bars": 702,
1815
+ "park bench": 703,
1816
+ "parking meter": 704,
1817
+ "partridge": 86,
1818
+ "passenger car, coach, carriage": 705,
1819
+ "patas, hussar monkey, Erythrocebus patas": 371,
1820
+ "patio, terrace": 706,
1821
+ "pay-phone, pay-station": 707,
1822
+ "peacock": 84,
1823
+ "pedestal, plinth, footstall": 708,
1824
+ "pelican": 144,
1825
+ "pencil box, pencil case": 709,
1826
+ "pencil sharpener": 710,
1827
+ "perfume, essence": 711,
1828
+ "photocopier": 713,
1829
+ "pick, plectrum, plectron": 714,
1830
+ "pickelhaube": 715,
1831
+ "picket fence, paling": 716,
1832
+ "pickup, pickup truck": 717,
1833
+ "pier": 718,
1834
+ "piggy bank, penny bank": 719,
1835
+ "pill bottle": 720,
1836
+ "pillow": 721,
1837
+ "pineapple, ananas": 953,
1838
+ "ping-pong ball": 722,
1839
+ "pinwheel": 723,
1840
+ "pirate, pirate ship": 724,
1841
+ "pitcher, ewer": 725,
1842
+ "pizza, pizza pie": 963,
1843
+ "plane, carpenter's plane, woodworking plane": 726,
1844
+ "planetarium": 727,
1845
+ "plastic bag": 728,
1846
+ "plate": 923,
1847
+ "plate rack": 729,
1848
+ "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus": 103,
1849
+ "plow, plough": 730,
1850
+ "plunger, plumber's helper": 731,
1851
+ "pole": 733,
1852
+ "polecat, fitch, foulmart, foumart, Mustela putorius": 358,
1853
+ "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria": 734,
1854
+ "pomegranate": 957,
1855
+ "poncho": 735,
1856
+ "pool table, billiard table, snooker table": 736,
1857
+ "pop bottle, soda bottle": 737,
1858
+ "porcupine, hedgehog": 334,
1859
+ "pot, flowerpot": 738,
1860
+ "potpie": 964,
1861
+ "potter's wheel": 739,
1862
+ "power drill": 740,
1863
+ "prairie chicken, prairie grouse, prairie fowl": 83,
1864
+ "prayer rug, prayer mat": 741,
1865
+ "pretzel": 932,
1866
+ "printer": 742,
1867
+ "prison, prison house": 743,
1868
+ "proboscis monkey, Nasalis larvatus": 376,
1869
+ "projectile, missile": 744,
1870
+ "projector": 745,
1871
+ "promontory, headland, head, foreland": 976,
1872
+ "ptarmigan": 81,
1873
+ "puck, hockey puck": 746,
1874
+ "puffer, pufferfish, blowfish, globefish": 397,
1875
+ "pug, pug-dog": 254,
1876
+ "punching bag, punch bag, punching ball, punchball": 747,
1877
+ "purse": 748,
1878
+ "quail": 85,
1879
+ "quill, quill pen": 749,
1880
+ "quilt, comforter, comfort, puff": 750,
1881
+ "racer, race car, racing car": 751,
1882
+ "racket, racquet": 752,
1883
+ "radiator": 753,
1884
+ "radio telescope, radio reflector": 755,
1885
+ "radio, wireless": 754,
1886
+ "rain barrel": 756,
1887
+ "ram, tup": 348,
1888
+ "rapeseed": 984,
1889
+ "recreational vehicle, RV, R.V.": 757,
1890
+ "red fox, Vulpes vulpes": 277,
1891
+ "red wine": 966,
1892
+ "red wolf, maned wolf, Canis rufus, Canis niger": 271,
1893
+ "red-backed sandpiper, dunlin, Erolia alpina": 140,
1894
+ "red-breasted merganser, Mergus serrator": 98,
1895
+ "redbone": 168,
1896
+ "redshank, Tringa totanus": 141,
1897
+ "reel": 758,
1898
+ "reflex camera": 759,
1899
+ "refrigerator, icebox": 760,
1900
+ "remote control, remote": 761,
1901
+ "restaurant, eating house, eating place, eatery": 762,
1902
+ "revolver, six-gun, six-shooter": 763,
1903
+ "rhinoceros beetle": 306,
1904
+ "rifle": 764,
1905
+ "ringlet, ringlet butterfly": 322,
1906
+ "ringneck snake, ring-necked snake, ring snake": 53,
1907
+ "robin, American robin, Turdus migratorius": 15,
1908
+ "rock beauty, Holocanthus tricolor": 392,
1909
+ "rock crab, Cancer irroratus": 119,
1910
+ "rock python, rock snake, Python sebae": 62,
1911
+ "rocking chair, rocker": 765,
1912
+ "rotisserie": 766,
1913
+ "rubber eraser, rubber, pencil eraser": 767,
1914
+ "ruddy turnstone, Arenaria interpres": 139,
1915
+ "ruffed grouse, partridge, Bonasa umbellus": 82,
1916
+ "rugby ball": 768,
1917
+ "rule, ruler": 769,
1918
+ "running shoe": 770,
1919
+ "safe": 771,
1920
+ "safety pin": 772,
1921
+ "saltshaker, salt shaker": 773,
1922
+ "sandal": 774,
1923
+ "sandbar, sand bar": 977,
1924
+ "sarong": 775,
1925
+ "sax, saxophone": 776,
1926
+ "scabbard": 777,
1927
+ "scale, weighing machine": 778,
1928
+ "schipperke": 223,
1929
+ "school bus": 779,
1930
+ "schooner": 780,
1931
+ "scoreboard": 781,
1932
+ "scorpion": 71,
1933
+ "screen, CRT screen": 782,
1934
+ "screw": 783,
1935
+ "screwdriver": 784,
1936
+ "scuba diver": 983,
1937
+ "sea anemone, anemone": 108,
1938
+ "sea cucumber, holothurian": 329,
1939
+ "sea lion": 150,
1940
+ "sea slug, nudibranch": 115,
1941
+ "sea snake": 65,
1942
+ "sea urchin": 328,
1943
+ "seashore, coast, seacoast, sea-coast": 978,
1944
+ "seat belt, seatbelt": 785,
1945
+ "sewing machine": 786,
1946
+ "shield, buckler": 787,
1947
+ "shoe shop, shoe-shop, shoe store": 788,
1948
+ "shoji": 789,
1949
+ "shopping basket": 790,
1950
+ "shopping cart": 791,
1951
+ "shovel": 792,
1952
+ "shower cap": 793,
1953
+ "shower curtain": 794,
1954
+ "siamang, Hylobates syndactylus, Symphalangus syndactylus": 369,
1955
+ "sidewinder, horned rattlesnake, Crotalus cerastes": 68,
1956
+ "silky terrier, Sydney silky": 201,
1957
+ "ski": 795,
1958
+ "ski mask": 796,
1959
+ "skunk, polecat, wood pussy": 361,
1960
+ "sleeping bag": 797,
1961
+ "slide rule, slipstick": 798,
1962
+ "sliding door": 799,
1963
+ "slot, one-armed bandit": 800,
1964
+ "sloth bear, Melursus ursinus, Ursus ursinus": 297,
1965
+ "slug": 114,
1966
+ "snail": 113,
1967
+ "snorkel": 801,
1968
+ "snow leopard, ounce, Panthera uncia": 289,
1969
+ "snowmobile": 802,
1970
+ "snowplow, snowplough": 803,
1971
+ "soap dispenser": 804,
1972
+ "soccer ball": 805,
1973
+ "sock": 806,
1974
+ "soft-coated wheaten terrier": 202,
1975
+ "solar dish, solar collector, solar furnace": 807,
1976
+ "sombrero": 808,
1977
+ "sorrel": 339,
1978
+ "soup bowl": 809,
1979
+ "space bar": 810,
1980
+ "space heater": 811,
1981
+ "space shuttle": 812,
1982
+ "spaghetti squash": 940,
1983
+ "spatula": 813,
1984
+ "speedboat": 814,
1985
+ "spider monkey, Ateles geoffroyi": 381,
1986
+ "spider web, spider's web": 815,
1987
+ "spindle": 816,
1988
+ "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish": 123,
1989
+ "spoonbill": 129,
1990
+ "sports car, sport car": 817,
1991
+ "spotlight, spot": 818,
1992
+ "spotted salamander, Ambystoma maculatum": 28,
1993
+ "squirrel monkey, Saimiri sciureus": 382,
1994
+ "stage": 819,
1995
+ "standard poodle": 267,
1996
+ "standard schnauzer": 198,
1997
+ "starfish, sea star": 327,
1998
+ "steam locomotive": 820,
1999
+ "steel arch bridge": 821,
2000
+ "steel drum": 822,
2001
+ "stethoscope": 823,
2002
+ "stingray": 6,
2003
+ "stinkhorn, carrion fungus": 994,
2004
+ "stole": 824,
2005
+ "stone wall": 825,
2006
+ "stopwatch, stop watch": 826,
2007
+ "stove": 827,
2008
+ "strainer": 828,
2009
+ "strawberry": 949,
2010
+ "street sign": 919,
2011
+ "streetcar, tram, tramcar, trolley, trolley car": 829,
2012
+ "stretcher": 830,
2013
+ "studio couch, day bed": 831,
2014
+ "stupa, tope": 832,
2015
+ "sturgeon": 394,
2016
+ "submarine, pigboat, sub, U-boat": 833,
2017
+ "suit, suit of clothes": 834,
2018
+ "sulphur butterfly, sulfur butterfly": 325,
2019
+ "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita": 89,
2020
+ "sundial": 835,
2021
+ "sunglass": 836,
2022
+ "sunglasses, dark glasses, shades": 837,
2023
+ "sunscreen, sunblock, sun blocker": 838,
2024
+ "suspension bridge": 839,
2025
+ "swab, swob, mop": 840,
2026
+ "sweatshirt": 841,
2027
+ "swimming trunks, bathing trunks": 842,
2028
+ "swing": 843,
2029
+ "switch, electric switch, electrical switch": 844,
2030
+ "syringe": 845,
2031
+ "tabby, tabby cat": 281,
2032
+ "table lamp": 846,
2033
+ "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui": 32,
2034
+ "tank, army tank, armored combat vehicle, armoured combat vehicle": 847,
2035
+ "tape player": 848,
2036
+ "tarantula": 76,
2037
+ "teapot": 849,
2038
+ "teddy, teddy bear": 850,
2039
+ "television, television system": 851,
2040
+ "tench, Tinca tinca": 0,
2041
+ "tennis ball": 852,
2042
+ "terrapin": 36,
2043
+ "thatch, thatched roof": 853,
2044
+ "theater curtain, theatre curtain": 854,
2045
+ "thimble": 855,
2046
+ "three-toed sloth, ai, Bradypus tridactylus": 364,
2047
+ "thresher, thrasher, threshing machine": 856,
2048
+ "throne": 857,
2049
+ "thunder snake, worm snake, Carphophis amoenus": 52,
2050
+ "tick": 78,
2051
+ "tiger beetle": 300,
2052
+ "tiger cat": 282,
2053
+ "tiger shark, Galeocerdo cuvieri": 3,
2054
+ "tiger, Panthera tigris": 292,
2055
+ "tile roof": 858,
2056
+ "timber wolf, grey wolf, gray wolf, Canis lupus": 269,
2057
+ "titi, titi monkey": 380,
2058
+ "toaster": 859,
2059
+ "tobacco shop, tobacconist shop, tobacconist": 860,
2060
+ "toilet seat": 861,
2061
+ "toilet tissue, toilet paper, bathroom tissue": 999,
2062
+ "torch": 862,
2063
+ "totem pole": 863,
2064
+ "toucan": 96,
2065
+ "tow truck, tow car, wrecker": 864,
2066
+ "toy poodle": 265,
2067
+ "toy terrier": 158,
2068
+ "toyshop": 865,
2069
+ "tractor": 866,
2070
+ "traffic light, traffic signal, stoplight": 920,
2071
+ "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi": 867,
2072
+ "tray": 868,
2073
+ "tree frog, tree-frog": 31,
2074
+ "trench coat": 869,
2075
+ "triceratops": 51,
2076
+ "tricycle, trike, velocipede": 870,
2077
+ "trifle": 927,
2078
+ "trilobite": 69,
2079
+ "trimaran": 871,
2080
+ "tripod": 872,
2081
+ "triumphal arch": 873,
2082
+ "trolleybus, trolley coach, trackless trolley": 874,
2083
+ "trombone": 875,
2084
+ "tub, vat": 876,
2085
+ "turnstile": 877,
2086
+ "tusker": 101,
2087
+ "typewriter keyboard": 878,
2088
+ "umbrella": 879,
2089
+ "unicycle, monocycle": 880,
2090
+ "upright, upright piano": 881,
2091
+ "vacuum, vacuum cleaner": 882,
2092
+ "valley, vale": 979,
2093
+ "vase": 883,
2094
+ "vault": 884,
2095
+ "velvet": 885,
2096
+ "vending machine": 886,
2097
+ "vestment": 887,
2098
+ "viaduct": 888,
2099
+ "vine snake": 59,
2100
+ "violin, fiddle": 889,
2101
+ "vizsla, Hungarian pointer": 211,
2102
+ "volcano": 980,
2103
+ "volleyball": 890,
2104
+ "vulture": 23,
2105
+ "waffle iron": 891,
2106
+ "walking stick, walkingstick, stick insect": 313,
2107
+ "wall clock": 892,
2108
+ "wallaby, brush kangaroo": 104,
2109
+ "wallet, billfold, notecase, pocketbook": 893,
2110
+ "wardrobe, closet, press": 894,
2111
+ "warplane, military plane": 895,
2112
+ "warthog": 343,
2113
+ "washbasin, handbasin, washbowl, lavabo, wash-hand basin": 896,
2114
+ "washer, automatic washer, washing machine": 897,
2115
+ "water bottle": 898,
2116
+ "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis": 346,
2117
+ "water jug": 899,
2118
+ "water ouzel, dipper": 20,
2119
+ "water snake": 58,
2120
+ "water tower": 900,
2121
+ "weasel": 356,
2122
+ "web site, website, internet site, site": 916,
2123
+ "weevil": 307,
2124
+ "whippet": 172,
2125
+ "whiptail, whiptail lizard": 41,
2126
+ "whiskey jug": 901,
2127
+ "whistle": 902,
2128
+ "white stork, Ciconia ciconia": 127,
2129
+ "white wolf, Arctic wolf, Canis lupus tundrarum": 270,
2130
+ "wig": 903,
2131
+ "wild boar, boar, Sus scrofa": 342,
2132
+ "window screen": 904,
2133
+ "window shade": 905,
2134
+ "wine bottle": 907,
2135
+ "wing": 908,
2136
+ "wire-haired fox terrier": 188,
2137
+ "wok": 909,
2138
+ "wolf spider, hunting spider": 77,
2139
+ "wombat": 106,
2140
+ "wood rabbit, cottontail, cottontail rabbit": 330,
2141
+ "wooden spoon": 910,
2142
+ "wool, woolen, woollen": 911,
2143
+ "worm fence, snake fence, snake-rail fence, Virginia fence": 912,
2144
+ "wreck": 913,
2145
+ "yawl": 914,
2146
+ "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum": 986,
2147
+ "yurt": 915,
2148
+ "zebra": 340,
2149
+ "zucchini, courgette": 939
2150
+ },
2151
+ "layer_norm_eps": 1e-05,
2152
+ "length_penalty": 1.0,
2153
+ "max_length": 20,
2154
+ "min_length": 0,
2155
+ "mlp_ratio": 4.0,
2156
+ "model_type": "swin",
2157
+ "no_repeat_ngram_size": 0,
2158
+ "num_beam_groups": 1,
2159
+ "num_beams": 1,
2160
+ "num_channels": 3,
2161
+ "num_heads": [
2162
+ 6,
2163
+ 12,
2164
+ 24,
2165
+ 48
2166
+ ],
2167
+ "num_layers": 4,
2168
+ "num_return_sequences": 1,
2169
+ "out_features": [
2170
+ "stage1",
2171
+ "stage2",
2172
+ "stage3",
2173
+ "stage4"
2174
+ ],
2175
+ "out_indices": [
2176
+ 1,
2177
+ 2,
2178
+ 3,
2179
+ 4
2180
+ ],
2181
+ "output_attentions": false,
2182
+ "output_hidden_states": false,
2183
+ "output_scores": false,
2184
+ "pad_token_id": null,
2185
+ "patch_size": 4,
2186
+ "path_norm": true,
2187
+ "prefix": null,
2188
+ "problem_type": null,
2189
+ "pruned_heads": {},
2190
+ "qkv_bias": true,
2191
+ "remove_invalid_values": false,
2192
+ "repetition_penalty": 1.0,
2193
+ "return_dict": true,
2194
+ "return_dict_in_generate": false,
2195
+ "sep_token_id": null,
2196
+ "stage_names": [
2197
+ "stem",
2198
+ "stage1",
2199
+ "stage2",
2200
+ "stage3",
2201
+ "stage4"
2202
+ ],
2203
+ "suppress_tokens": null,
2204
+ "task_specific_params": null,
2205
+ "temperature": 1.0,
2206
+ "tf_legacy_loss": false,
2207
+ "tie_encoder_decoder": false,
2208
+ "tie_word_embeddings": true,
2209
+ "tokenizer_class": null,
2210
+ "top_k": 50,
2211
+ "top_p": 1.0,
2212
+ "torch_dtype": "float32",
2213
+ "torchscript": false,
2214
+ "typical_p": 1.0,
2215
+ "use_absolute_embeddings": false,
2216
+ "use_bfloat16": false,
2217
+ "window_size": 12
2218
+ },
2219
+ "backbone_kwargs": null,
2220
+ "bad_words_ids": null,
2221
+ "begin_suppress_tokens": null,
2222
+ "bos_token_id": null,
2223
+ "chunk_size_feed_forward": 0,
2224
+ "class_weight": 2.0,
2225
+ "common_stride": 4,
2226
+ "cross_attention_hidden_size": null,
2227
+ "decoder_layers": 10,
2228
+ "decoder_start_token_id": null,
2229
+ "dice_weight": 5.0,
2230
+ "dim_feedforward": 2048,
2231
+ "diversity_penalty": 0.0,
2232
+ "do_sample": false,
2233
+ "dropout": 0.0,
2234
+ "early_stopping": false,
2235
+ "encoder_feedforward_dim": 1024,
2236
+ "encoder_layers": 6,
2237
+ "encoder_no_repeat_ngram_size": 0,
2238
+ "enforce_input_proj": false,
2239
+ "enforce_input_projection": false,
2240
+ "eos_token_id": null,
2241
+ "exponential_decay_length_penalty": null,
2242
+ "feature_size": 256,
2243
+ "feature_strides": [
2244
+ 4,
2245
+ 8,
2246
+ 16,
2247
+ 32
2248
+ ],
2249
+ "finetuning_task": null,
2250
+ "forced_bos_token_id": null,
2251
+ "forced_eos_token_id": null,
2252
+ "hidden_dim": 256,
2253
+ "id2label": {
2254
+ "0": "person",
2255
+ "1": "bicycle",
2256
+ "2": "car",
2257
+ "3": "motorcycle",
2258
+ "4": "airplane",
2259
+ "5": "bus",
2260
+ "6": "train",
2261
+ "7": "truck",
2262
+ "8": "boat",
2263
+ "9": "traffic light",
2264
+ "10": "fire hydrant",
2265
+ "11": "stop sign",
2266
+ "12": "parking meter",
2267
+ "13": "bench",
2268
+ "14": "bird",
2269
+ "15": "cat",
2270
+ "16": "dog",
2271
+ "17": "horse",
2272
+ "18": "sheep",
2273
+ "19": "cow",
2274
+ "20": "elephant",
2275
+ "21": "bear",
2276
+ "22": "zebra",
2277
+ "23": "giraffe",
2278
+ "24": "backpack",
2279
+ "25": "umbrella",
2280
+ "26": "handbag",
2281
+ "27": "tie",
2282
+ "28": "suitcase",
2283
+ "29": "frisbee",
2284
+ "30": "skis",
2285
+ "31": "snowboard",
2286
+ "32": "sports ball",
2287
+ "33": "kite",
2288
+ "34": "baseball bat",
2289
+ "35": "baseball glove",
2290
+ "36": "skateboard",
2291
+ "37": "surfboard",
2292
+ "38": "tennis racket",
2293
+ "39": "bottle",
2294
+ "40": "wine glass",
2295
+ "41": "cup",
2296
+ "42": "fork",
2297
+ "43": "knife",
2298
+ "44": "spoon",
2299
+ "45": "bowl",
2300
+ "46": "banana",
2301
+ "47": "apple",
2302
+ "48": "sandwich",
2303
+ "49": "orange",
2304
+ "50": "broccoli",
2305
+ "51": "carrot",
2306
+ "52": "hot dog",
2307
+ "53": "pizza",
2308
+ "54": "donut",
2309
+ "55": "cake",
2310
+ "56": "chair",
2311
+ "57": "couch",
2312
+ "58": "potted plant",
2313
+ "59": "bed",
2314
+ "60": "dining table",
2315
+ "61": "toilet",
2316
+ "62": "tv",
2317
+ "63": "laptop",
2318
+ "64": "mouse",
2319
+ "65": "remote",
2320
+ "66": "keyboard",
2321
+ "67": "cell phone",
2322
+ "68": "microwave",
2323
+ "69": "oven",
2324
+ "70": "toaster",
2325
+ "71": "sink",
2326
+ "72": "refrigerator",
2327
+ "73": "book",
2328
+ "74": "clock",
2329
+ "75": "vase",
2330
+ "76": "scissors",
2331
+ "77": "teddy bear",
2332
+ "78": "hair drier",
2333
+ "79": "toothbrush",
2334
+ "80": "banner",
2335
+ "81": "blanket",
2336
+ "82": "bridge",
2337
+ "83": "cardboard",
2338
+ "84": "counter",
2339
+ "85": "curtain",
2340
+ "86": "door-stuff",
2341
+ "87": "floor-wood",
2342
+ "88": "flower",
2343
+ "89": "fruit",
2344
+ "90": "gravel",
2345
+ "91": "house",
2346
+ "92": "light",
2347
+ "93": "mirror-stuff",
2348
+ "94": "net",
2349
+ "95": "pillow",
2350
+ "96": "platform",
2351
+ "97": "playingfield",
2352
+ "98": "railroad",
2353
+ "99": "river",
2354
+ "100": "road",
2355
+ "101": "roof",
2356
+ "102": "sand",
2357
+ "103": "sea",
2358
+ "104": "shelf",
2359
+ "105": "snow",
2360
+ "106": "stairs",
2361
+ "107": "tent",
2362
+ "108": "towel",
2363
+ "109": "wall-brick",
2364
+ "110": "wall-stone",
2365
+ "111": "wall-tile",
2366
+ "112": "wall-wood",
2367
+ "113": "water-other",
2368
+ "114": "window-blind",
2369
+ "115": "window-other",
2370
+ "116": "tree-merged",
2371
+ "117": "fence-merged",
2372
+ "118": "ceiling-merged",
2373
+ "119": "sky-other-merged",
2374
+ "120": "cabinet-merged",
2375
+ "121": "table-merged",
2376
+ "122": "floor-other-merged",
2377
+ "123": "pavement-merged",
2378
+ "124": "mountain-merged",
2379
+ "125": "grass-merged",
2380
+ "126": "dirt-merged",
2381
+ "127": "paper-merged",
2382
+ "128": "food-other-merged",
2383
+ "129": "building-other-merged",
2384
+ "130": "rock-merged",
2385
+ "131": "wall-other-merged",
2386
+ "132": "rug-merged"
2387
+ },
2388
+ "ignore_value": 255,
2389
+ "importance_sample_ratio": 0.75,
2390
+ "init_std": 0.02,
2391
+ "init_xavier_std": 1.0,
2392
+ "is_decoder": false,
2393
+ "is_encoder_decoder": false,
2394
+ "label2id": {
2395
+ "airplane": 4,
2396
+ "apple": 47,
2397
+ "backpack": 24,
2398
+ "banana": 46,
2399
+ "banner": 80,
2400
+ "baseball bat": 34,
2401
+ "baseball glove": 35,
2402
+ "bear": 21,
2403
+ "bed": 59,
2404
+ "bench": 13,
2405
+ "bicycle": 1,
2406
+ "bird": 14,
2407
+ "blanket": 81,
2408
+ "boat": 8,
2409
+ "book": 73,
2410
+ "bottle": 39,
2411
+ "bowl": 45,
2412
+ "bridge": 82,
2413
+ "broccoli": 50,
2414
+ "building-other-merged": 129,
2415
+ "bus": 5,
2416
+ "cabinet-merged": 120,
2417
+ "cake": 55,
2418
+ "car": 2,
2419
+ "cardboard": 83,
2420
+ "carrot": 51,
2421
+ "cat": 15,
2422
+ "ceiling-merged": 118,
2423
+ "cell phone": 67,
2424
+ "chair": 56,
2425
+ "clock": 74,
2426
+ "couch": 57,
2427
+ "counter": 84,
2428
+ "cow": 19,
2429
+ "cup": 41,
2430
+ "curtain": 85,
2431
+ "dining table": 60,
2432
+ "dirt-merged": 126,
2433
+ "dog": 16,
2434
+ "donut": 54,
2435
+ "door-stuff": 86,
2436
+ "elephant": 20,
2437
+ "fence-merged": 117,
2438
+ "fire hydrant": 10,
2439
+ "floor-other-merged": 122,
2440
+ "floor-wood": 87,
2441
+ "flower": 88,
2442
+ "food-other-merged": 128,
2443
+ "fork": 42,
2444
+ "frisbee": 29,
2445
+ "fruit": 89,
2446
+ "giraffe": 23,
2447
+ "grass-merged": 125,
2448
+ "gravel": 90,
2449
+ "hair drier": 78,
2450
+ "handbag": 26,
2451
+ "horse": 17,
2452
+ "hot dog": 52,
2453
+ "house": 91,
2454
+ "keyboard": 66,
2455
+ "kite": 33,
2456
+ "knife": 43,
2457
+ "laptop": 63,
2458
+ "light": 92,
2459
+ "microwave": 68,
2460
+ "mirror-stuff": 93,
2461
+ "motorcycle": 3,
2462
+ "mountain-merged": 124,
2463
+ "mouse": 64,
2464
+ "net": 94,
2465
+ "orange": 49,
2466
+ "oven": 69,
2467
+ "paper-merged": 127,
2468
+ "parking meter": 12,
2469
+ "pavement-merged": 123,
2470
+ "person": 0,
2471
+ "pillow": 95,
2472
+ "pizza": 53,
2473
+ "platform": 96,
2474
+ "playingfield": 97,
2475
+ "potted plant": 58,
2476
+ "railroad": 98,
2477
+ "refrigerator": 72,
2478
+ "remote": 65,
2479
+ "river": 99,
2480
+ "road": 100,
2481
+ "rock-merged": 130,
2482
+ "roof": 101,
2483
+ "rug-merged": 132,
2484
+ "sand": 102,
2485
+ "sandwich": 48,
2486
+ "scissors": 76,
2487
+ "sea": 103,
2488
+ "sheep": 18,
2489
+ "shelf": 104,
2490
+ "sink": 71,
2491
+ "skateboard": 36,
2492
+ "skis": 30,
2493
+ "sky-other-merged": 119,
2494
+ "snow": 105,
2495
+ "snowboard": 31,
2496
+ "spoon": 44,
2497
+ "sports ball": 32,
2498
+ "stairs": 106,
2499
+ "stop sign": 11,
2500
+ "suitcase": 28,
2501
+ "surfboard": 37,
2502
+ "table-merged": 121,
2503
+ "teddy bear": 77,
2504
+ "tennis racket": 38,
2505
+ "tent": 107,
2506
+ "tie": 27,
2507
+ "toaster": 70,
2508
+ "toilet": 61,
2509
+ "toothbrush": 79,
2510
+ "towel": 108,
2511
+ "traffic light": 9,
2512
+ "train": 6,
2513
+ "tree-merged": 116,
2514
+ "truck": 7,
2515
+ "tv": 62,
2516
+ "umbrella": 25,
2517
+ "vase": 75,
2518
+ "wall-brick": 109,
2519
+ "wall-other-merged": 131,
2520
+ "wall-stone": 110,
2521
+ "wall-tile": 111,
2522
+ "wall-wood": 112,
2523
+ "water-other": 113,
2524
+ "window-blind": 114,
2525
+ "window-other": 115,
2526
+ "wine glass": 40,
2527
+ "zebra": 22
2528
+ },
2529
+ "length_penalty": 1.0,
2530
+ "mask_feature_size": 256,
2531
+ "mask_weight": 5.0,
2532
+ "max_length": 20,
2533
+ "min_length": 0,
2534
+ "model_type": "mask2former",
2535
+ "no_object_weight": 0.1,
2536
+ "no_repeat_ngram_size": 0,
2537
+ "num_attention_heads": 8,
2538
+ "num_beam_groups": 1,
2539
+ "num_beams": 1,
2540
+ "num_hidden_layers": 10,
2541
+ "num_queries": 200,
2542
+ "num_return_sequences": 1,
2543
+ "output_attentions": false,
2544
+ "output_auxiliary_logits": null,
2545
+ "output_hidden_states": false,
2546
+ "output_scores": false,
2547
+ "oversample_ratio": 3.0,
2548
+ "pad_token_id": null,
2549
+ "pre_norm": false,
2550
+ "prefix": null,
2551
+ "problem_type": null,
2552
+ "pruned_heads": {},
2553
+ "remove_invalid_values": false,
2554
+ "repetition_penalty": 1.0,
2555
+ "return_dict": true,
2556
+ "return_dict_in_generate": false,
2557
+ "sep_token_id": null,
2558
+ "suppress_tokens": null,
2559
+ "task_specific_params": null,
2560
+ "temperature": 1.0,
2561
+ "tf_legacy_loss": false,
2562
+ "tie_encoder_decoder": false,
2563
+ "tie_word_embeddings": true,
2564
+ "tokenizer_class": null,
2565
+ "top_k": 50,
2566
+ "top_p": 1.0,
2567
+ "torch_dtype": "float32",
2568
+ "torchscript": false,
2569
+ "train_num_points": 12544,
2570
+ "transformers_version": "4.47.0",
2571
+ "typical_p": 1.0,
2572
+ "use_auxiliary_loss": true,
2573
+ "use_bfloat16": false,
2574
+ "use_pretrained_backbone": false,
2575
+ "use_timm_backbone": false
2576
+ },
2577
+ "max_dynamic_patch": 12,
2578
+ "min_dynamic_patch": 1,
2579
+ "model_type": "sa2va_chat",
2580
+ "num_m2f_proposals": 100,
2581
+ "num_m2f_queries": 200,
2582
+ "pad2square": false,
2583
+ "ps_version": "v2",
2584
+ "select_layer": -1,
2585
+ "template": "internlm2_chat",
2586
+ "tie_word_embeddings": false,
2587
+ "torch_dtype": "bfloat16",
2588
+ "transformers_version": null,
2589
+ "use_backbone_lora": 0,
2590
+ "use_llm_lora": 0,
2591
+ "use_thumbnail": true,
2592
+ "vision_config": {
2593
+ "_attn_implementation_autoset": false,
2594
+ "_name_or_path": "",
2595
+ "add_cross_attention": false,
2596
+ "architectures": [
2597
+ "InternVisionModel"
2598
+ ],
2599
+ "attention_dropout": 0.0,
2600
+ "bad_words_ids": null,
2601
+ "begin_suppress_tokens": null,
2602
+ "bos_token_id": null,
2603
+ "chunk_size_feed_forward": 0,
2604
+ "cross_attention_hidden_size": null,
2605
+ "decoder_start_token_id": null,
2606
+ "diversity_penalty": 0.0,
2607
+ "do_sample": false,
2608
+ "drop_path_rate": 0.0,
2609
+ "dropout": 0.0,
2610
+ "early_stopping": false,
2611
+ "encoder_no_repeat_ngram_size": 0,
2612
+ "eos_token_id": null,
2613
+ "exponential_decay_length_penalty": null,
2614
+ "finetuning_task": null,
2615
+ "forced_bos_token_id": null,
2616
+ "forced_eos_token_id": null,
2617
+ "hidden_act": "gelu",
2618
+ "hidden_size": 1024,
2619
+ "id2label": {
2620
+ "0": "LABEL_0",
2621
+ "1": "LABEL_1"
2622
+ },
2623
+ "image_size": 448,
2624
+ "initializer_factor": 1.0,
2625
+ "initializer_range": 0.02,
2626
+ "intermediate_size": 4096,
2627
+ "is_decoder": false,
2628
+ "is_encoder_decoder": false,
2629
+ "label2id": {
2630
+ "LABEL_0": 0,
2631
+ "LABEL_1": 1
2632
+ },
2633
+ "layer_norm_eps": 1e-06,
2634
+ "length_penalty": 1.0,
2635
+ "max_length": 20,
2636
+ "min_length": 0,
2637
+ "model_type": "intern_vit_6b",
2638
+ "no_repeat_ngram_size": 0,
2639
+ "norm_type": "layer_norm",
2640
+ "num_attention_heads": 16,
2641
+ "num_beam_groups": 1,
2642
+ "num_beams": 1,
2643
+ "num_channels": 3,
2644
+ "num_hidden_layers": 24,
2645
+ "num_return_sequences": 1,
2646
+ "output_attentions": false,
2647
+ "output_hidden_states": false,
2648
+ "output_scores": false,
2649
+ "pad_token_id": null,
2650
+ "patch_size": 14,
2651
+ "prefix": null,
2652
+ "problem_type": null,
2653
+ "pruned_heads": {},
2654
+ "qk_normalization": false,
2655
+ "qkv_bias": true,
2656
+ "remove_invalid_values": false,
2657
+ "repetition_penalty": 1.0,
2658
+ "return_dict": true,
2659
+ "return_dict_in_generate": false,
2660
+ "sep_token_id": null,
2661
+ "suppress_tokens": null,
2662
+ "task_specific_params": null,
2663
+ "temperature": 1.0,
2664
+ "tf_legacy_loss": false,
2665
+ "tie_encoder_decoder": false,
2666
+ "tie_word_embeddings": true,
2667
+ "tokenizer_class": null,
2668
+ "top_k": 50,
2669
+ "top_p": 1.0,
2670
+ "torch_dtype": "bfloat16",
2671
+ "torchscript": false,
2672
+ "transformers_version": "4.47.0",
2673
+ "typical_p": 1.0,
2674
+ "use_bfloat16": true,
2675
+ "use_flash_attn": true
2676
+ }
2677
+ }
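A minimal sketch of how a config like the one above is typically loaded. The repository id below is a placeholder, and `trust_remote_code=True` is assumed because `sa2va_chat` is a custom `model_type` implemented by the Python files added in this commit:

import torch
from transformers import AutoConfig, AutoModel

repo_id = "path/to/this/sa2va/repo"  # placeholder -- substitute the actual repo id or a local path

# The custom "sa2va_chat" model_type is resolved via the remote code shipped with the repo.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.model_type)                # "sa2va_chat"
print(config.vision_config.image_size)  # 448, per the vision_config block above

# Loading the model follows the same pattern; bfloat16 matches the config's torch_dtype.
model = AutoModel.from_pretrained(repo_id, torch_dtype=torch.bfloat16, trust_remote_code=True)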
configuration_intern_vit.py ADDED
@@ -0,0 +1,120 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import os
8
+ from typing import Union
9
+
10
+ from transformers.configuration_utils import PretrainedConfig
11
+ from transformers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class InternVisionConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
19
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+ Args:
25
+ num_channels (`int`, *optional*, defaults to 3):
26
+ Number of color channels in the input images (e.g., 3 for RGB).
27
+ patch_size (`int`, *optional*, defaults to 14):
28
+ The size (resolution) of each patch.
29
+ image_size (`int`, *optional*, defaults to 224):
30
+ The size (resolution) of each image.
31
+ qkv_bias (`bool`, *optional*, defaults to `False`):
32
+ Whether to add a bias to the query, key and value projections in the self-attention layers.
33
+ hidden_size (`int`, *optional*, defaults to 3200):
34
+ Dimensionality of the encoder layers and the pooler layer.
35
+ num_attention_heads (`int`, *optional*, defaults to 25):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ intermediate_size (`int`, *optional*, defaults to 12800):
38
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
39
+ qk_normalization (`bool`, *optional*, defaults to `True`):
40
+ Whether to normalize the queries and keys in the self-attention layers.
41
+ num_hidden_layers (`int`, *optional*, defaults to 48):
42
+ Number of hidden layers in the Transformer encoder.
43
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
44
+ Whether to use flash attention mechanism.
45
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
46
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
47
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
48
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
49
+ The epsilon used by the layer normalization layers.
50
+ dropout (`float`, *optional*, defaults to 0.0):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
53
+ Dropout rate for stochastic depth.
54
+ attention_dropout (`float`, *optional*, defaults to 0.0):
55
+ The dropout ratio for the attention probabilities.
56
+ initializer_range (`float`, *optional*, defaults to 0.02):
57
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
58
+ initializer_factor (`float`, *optional*, defaults to 0.1):
59
+ A factor for layer scale.
60
+ """
61
+
62
+ model_type = 'intern_vit_6b'
63
+
64
+ def __init__(
65
+ self,
66
+ num_channels=3,
67
+ patch_size=14,
68
+ image_size=224,
69
+ qkv_bias=False,
70
+ hidden_size=3200,
71
+ num_attention_heads=25,
72
+ intermediate_size=12800,
73
+ qk_normalization=True,
74
+ num_hidden_layers=48,
75
+ use_flash_attn=True,
76
+ hidden_act='gelu',
77
+ norm_type='rms_norm',
78
+ layer_norm_eps=1e-6,
79
+ dropout=0.0,
80
+ drop_path_rate=0.0,
81
+ attention_dropout=0.0,
82
+ initializer_range=0.02,
83
+ initializer_factor=0.1,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(**kwargs)
87
+
88
+ self.hidden_size = hidden_size
89
+ self.intermediate_size = intermediate_size
90
+ self.dropout = dropout
91
+ self.drop_path_rate = drop_path_rate
92
+ self.num_hidden_layers = num_hidden_layers
93
+ self.num_attention_heads = num_attention_heads
94
+ self.num_channels = num_channels
95
+ self.patch_size = patch_size
96
+ self.image_size = image_size
97
+ self.initializer_range = initializer_range
98
+ self.initializer_factor = initializer_factor
99
+ self.attention_dropout = attention_dropout
100
+ self.layer_norm_eps = layer_norm_eps
101
+ self.hidden_act = hidden_act
102
+ self.norm_type = norm_type
103
+ self.qkv_bias = qkv_bias
104
+ self.qk_normalization = qk_normalization
105
+ self.use_flash_attn = use_flash_attn
106
+
107
+ @classmethod
108
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
109
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
110
+
111
+ if 'vision_config' in config_dict:
112
+ config_dict = config_dict['vision_config']
113
+
114
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
115
+ logger.warning(
116
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
117
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
118
+ )
119
+
120
+ return cls.from_dict(config_dict, **kwargs)
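
A minimal usage sketch (not part of the uploaded files): instantiating `InternVisionConfig` directly and overriding a couple of the defaults documented above. The plain `import` assumes the script sits next to these files; inside the package the relative import `.configuration_intern_vit` is used, as elsewhere in this upload.

```python
# Minimal sketch: construct the vision config with a few overrides; every other
# field falls back to the defaults listed in the docstring above.
from configuration_intern_vit import InternVisionConfig

config = InternVisionConfig(image_size=448, use_flash_attn=False)
print(config.model_type)    # intern_vit_6b
print(config.hidden_size)   # 3200 (default)
print(config.image_size)    # 448 (override)
```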
configuration_internlm2.py ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ InternLM2 model configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24
+
25
+
26
+ # Modified from transformers.model.llama.configuration_llama.LlamaConfig
27
+ class InternLM2Config(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
30
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
31
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
32
+
33
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34
+ documentation from [`PretrainedConfig`] for more information.
35
+
36
+
37
+ Args:
38
+ vocab_size (`int`, *optional*, defaults to 32000):
39
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
40
+ `inputs_ids` passed when calling [`InternLM2Model`]
41
+ hidden_size (`int`, *optional*, defaults to 4096):
42
+ Dimension of the hidden representations.
43
+ intermediate_size (`int`, *optional*, defaults to 11008):
44
+ Dimension of the MLP representations.
45
+ num_hidden_layers (`int`, *optional*, defaults to 32):
46
+ Number of hidden layers in the Transformer encoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 32):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ num_key_value_heads (`int`, *optional*):
50
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
51
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
52
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
53
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
54
+ by meanpooling all the original heads within that group. For more details checkout [this
55
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
56
+ `num_attention_heads`.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58
+ The non-linear activation function (function or string) in the decoder.
59
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
60
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
61
+ just in case (e.g., 512 or 1024 or 2048).
62
+ initializer_range (`float`, *optional*, defaults to 0.02):
63
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
+ rms_norm_eps (`float`, *optional*, defaults to 1e-6):
65
+ The epsilon used by the rms normalization layers.
66
+ use_cache (`bool`, *optional*, defaults to `True`):
67
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
68
+ relevant if `config.is_decoder=True`.
69
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
+ Whether to tie weight embeddings
71
+ Example:
72
+
73
+ """
74
+ model_type = 'internlm2'
75
+ _auto_class = 'AutoConfig'
76
+
77
+ def __init__( # pylint: disable=W0102
78
+ self,
79
+ vocab_size=103168,
80
+ hidden_size=4096,
81
+ intermediate_size=11008,
82
+ num_hidden_layers=32,
83
+ num_attention_heads=32,
84
+ num_key_value_heads=None,
85
+ hidden_act='silu',
86
+ max_position_embeddings=2048,
87
+ initializer_range=0.02,
88
+ rms_norm_eps=1e-6,
89
+ use_cache=True,
90
+ pad_token_id=0,
91
+ bos_token_id=1,
92
+ eos_token_id=2,
93
+ tie_word_embeddings=False,
94
+ bias=True,
95
+ rope_theta=10000,
96
+ rope_scaling=None,
97
+ attn_implementation='eager',
98
+ **kwargs,
99
+ ):
100
+ self.vocab_size = vocab_size
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.hidden_size = hidden_size
103
+ self.intermediate_size = intermediate_size
104
+ self.num_hidden_layers = num_hidden_layers
105
+ self.num_attention_heads = num_attention_heads
106
+ self.bias = bias
107
+
108
+ if num_key_value_heads is None:
109
+ num_key_value_heads = num_attention_heads
110
+ self.num_key_value_heads = num_key_value_heads
111
+
112
+ self.hidden_act = hidden_act
113
+ self.initializer_range = initializer_range
114
+ self.rms_norm_eps = rms_norm_eps
115
+ self.use_cache = use_cache
116
+ self.rope_theta = rope_theta
117
+ self.rope_scaling = rope_scaling
118
+ self._rope_scaling_validation()
119
+
120
+ self.attn_implementation = attn_implementation
121
+ if self.attn_implementation is None:
122
+ self.attn_implementation = 'eager'
123
+ super().__init__(
124
+ pad_token_id=pad_token_id,
125
+ bos_token_id=bos_token_id,
126
+ eos_token_id=eos_token_id,
127
+ tie_word_embeddings=tie_word_embeddings,
128
+ **kwargs,
129
+ )
130
+
131
+ def _rope_scaling_validation(self):
132
+ """
133
+ Validate the `rope_scaling` configuration.
134
+ """
135
+ if self.rope_scaling is None:
136
+ return
137
+
138
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
139
+ raise ValueError(
140
+ '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
141
+ f'got {self.rope_scaling}'
142
+ )
143
+ rope_scaling_type = self.rope_scaling.get('type', None)
144
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
145
+ if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
146
+ raise ValueError(
147
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
148
+ )
149
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
150
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
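
A small sketch of what `_rope_scaling_validation` accepts, assuming the file is importable as `configuration_internlm2` and using illustrative values only:

```python
# Minimal sketch: a rope_scaling dict that passes _rope_scaling_validation --
# exactly two keys, a supported type, and a float factor >= 1.0.
from configuration_internlm2 import InternLM2Config

cfg = InternLM2Config(
    max_position_embeddings=8192,                      # illustrative override
    rope_scaling={'type': 'dynamic', 'factor': 2.0},   # {'type': 'linear', ...} also passes
)

# Both of these would raise ValueError in _rope_scaling_validation:
#   rope_scaling={'type': 'yarn', 'factor': 2.0}    -> unsupported type
#   rope_scaling={'type': 'linear', 'factor': 2}    -> int, not float
```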
configuration_phi3.py ADDED
@@ -0,0 +1,211 @@
1
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """ Phi-3 model configuration"""
16
+
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+ PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json',
25
+ 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json',
26
+ }
27
+
28
+
29
+ class Phi3Config(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
32
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33
+ defaults will yield a similar configuration to that of the
34
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
35
+
36
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37
+ documentation from [`PretrainedConfig`] for more information.
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 32064):
41
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`Phi3Model`].
43
+ hidden_size (`int`, *optional*, defaults to 3072):
44
+ Dimension of the hidden representations.
45
+ intermediate_size (`int`, *optional*, defaults to 8192):
46
+ Dimension of the MLP representations.
47
+ num_hidden_layers (`int`, *optional*, defaults to 32):
48
+ Number of hidden layers in the Transformer decoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 32):
50
+ Number of attention heads for each attention layer in the Transformer decoder.
51
+ num_key_value_heads (`int`, *optional*):
52
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
55
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
+ by meanpooling all the original heads within that group. For more details checkout [this
57
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
58
+ `num_attention_heads`.
59
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
60
+ Dropout probability for mlp outputs.
61
+ embd_pdrop (`float`, *optional*, defaults to 0.0):
62
+ The dropout ratio for the embeddings.
63
+ attention_dropout (`float`, *optional*, defaults to 0.0):
64
+ The dropout ratio after computing the attention scores.
65
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
66
+ The non-linear activation function (function or string) in the decoder.
67
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
68
+ The maximum sequence length that this model might ever be used with.
69
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
70
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
71
+ original RoPE embeddings when using long scaling.
72
+ initializer_range (`float`, *optional*, defaults to 0.02):
73
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
74
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
75
+ The epsilon value used for the RMSNorm.
76
+ use_cache (`bool`, *optional*, defaults to `True`):
77
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
78
+ relevant if `config.is_decoder=True`.
79
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80
+ Whether to tie weight embeddings
81
+ rope_theta (`float`, *optional*, defaults to 10000.0):
82
+ The base period of the RoPE embeddings.
83
+ rope_scaling (`dict`, *optional*):
84
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
85
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
86
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
87
+ divided by the number of attention heads divided by 2.
88
+ bos_token_id (`int`, *optional*, defaults to 1):
89
+ The id of the "beginning-of-sequence" token.
90
+ eos_token_id (`int`, *optional*, defaults to 32000):
91
+ The id of the "end-of-sequence" token.
92
+ pad_token_id (`int`, *optional*, defaults to 32000):
93
+ The id of the padding token.
94
+ sliding_window (`int`, *optional*):
95
+ Sliding window attention window size. If `None`, no sliding window is applied.
96
+
97
+ Example:
98
+
99
+ ```python
100
+ >>> from transformers import Phi3Model, Phi3Config
101
+
102
+ >>> # Initializing a Phi-3 style configuration
103
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
104
+
105
+ >>> # Initializing a model from the configuration
106
+ >>> model = Phi3Model(configuration)
107
+
108
+ >>> # Accessing the model configuration
109
+ >>> configuration = model.config
110
+ ```"""
111
+
112
+ model_type = 'phi3'
113
+ keys_to_ignore_at_inference = ['past_key_values']
114
+
115
+ def __init__(
116
+ self,
117
+ vocab_size=32064,
118
+ hidden_size=3072,
119
+ intermediate_size=8192,
120
+ num_hidden_layers=32,
121
+ num_attention_heads=32,
122
+ num_key_value_heads=None,
123
+ resid_pdrop=0.0,
124
+ embd_pdrop=0.0,
125
+ attention_dropout=0.0,
126
+ hidden_act='silu',
127
+ max_position_embeddings=4096,
128
+ original_max_position_embeddings=4096,
129
+ initializer_range=0.02,
130
+ rms_norm_eps=1e-5,
131
+ use_cache=True,
132
+ tie_word_embeddings=False,
133
+ rope_theta=10000.0,
134
+ rope_scaling=None,
135
+ bos_token_id=1,
136
+ eos_token_id=32000,
137
+ pad_token_id=32000,
138
+ sliding_window=None,
139
+ **kwargs,
140
+ ):
141
+ self.vocab_size = vocab_size
142
+ self.hidden_size = hidden_size
143
+ self.intermediate_size = intermediate_size
144
+ self.num_hidden_layers = num_hidden_layers
145
+ self.num_attention_heads = num_attention_heads
146
+
147
+ if num_key_value_heads is None:
148
+ num_key_value_heads = num_attention_heads
149
+
150
+ self.num_key_value_heads = num_key_value_heads
151
+ self.resid_pdrop = resid_pdrop
152
+ self.embd_pdrop = embd_pdrop
153
+ self.attention_dropout = attention_dropout
154
+ self.hidden_act = hidden_act
155
+ self.max_position_embeddings = max_position_embeddings
156
+ self.original_max_position_embeddings = original_max_position_embeddings
157
+ self.initializer_range = initializer_range
158
+ self.rms_norm_eps = rms_norm_eps
159
+ self.use_cache = use_cache
160
+ self.rope_theta = rope_theta
161
+ self.rope_scaling = rope_scaling
162
+ self._rope_scaling_validation()
163
+ self.sliding_window = sliding_window
164
+
165
+ super().__init__(
166
+ bos_token_id=bos_token_id,
167
+ eos_token_id=eos_token_id,
168
+ pad_token_id=pad_token_id,
169
+ tie_word_embeddings=tie_word_embeddings,
170
+ **kwargs,
171
+ )
172
+
173
+ def _rope_scaling_validation(self):
174
+ """
175
+ Validate the `rope_scaling` configuration.
176
+ """
177
+ if self.rope_scaling is None:
178
+ return
179
+
180
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
181
+ raise ValueError(
182
+ '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, '
183
+ f'got {self.rope_scaling}'
184
+ )
185
+ rope_scaling_type = self.rope_scaling.get('type', None)
186
+ rope_scaling_short_factor = self.rope_scaling.get('short_factor', None)
187
+ rope_scaling_long_factor = self.rope_scaling.get('long_factor', None)
188
+ if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']:
189
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
190
+ if not (
191
+ isinstance(rope_scaling_short_factor, list)
192
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
193
+ ):
194
+ raise ValueError(
195
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
196
+ )
197
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
198
+ raise ValueError(
199
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
200
+ )
201
+ if not (
202
+ isinstance(rope_scaling_long_factor, list)
203
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
204
+ ):
205
+ raise ValueError(
206
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
207
+ )
208
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
209
+ raise ValueError(
210
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
211
+ )
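
A small sketch of the constraint enforced above, assuming this class's defaults: each factor list must have `hidden_size // num_attention_heads // 2` entries, i.e. 3072 // 32 // 2 = 48.

```python
# Minimal sketch: a rope_scaling dict that satisfies _rope_scaling_validation
# with the default hidden_size=3072 and num_attention_heads=32.
from configuration_phi3 import Phi3Config

half_head_dim = 3072 // 32 // 2  # 48 entries required per factor list
cfg = Phi3Config(
    max_position_embeddings=131072,   # illustrative long-context override
    rope_scaling={
        'type': 'su',                            # 'yarn' is the other accepted type
        'short_factor': [1.0] * half_head_dim,
        'long_factor': [1.0] * half_head_dim,
    },
)
```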
configuration_sa2va_chat.py ADDED
@@ -0,0 +1,122 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import copy
8
+
9
+ from .configuration_internlm2 import InternLM2Config
10
+ from .configuration_phi3 import Phi3Config
11
+ from transformers import AutoConfig, LlamaConfig, Qwen2Config, Mask2FormerConfig
12
+ from transformers.configuration_utils import PretrainedConfig
13
+ from transformers.utils import logging
14
+
15
+ from .configuration_intern_vit import InternVisionConfig
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+
20
+ class Sa2VAChatConfig(PretrainedConfig):
21
+ model_type = 'sa2va_chat'
22
+ is_composition = True
23
+
24
+ def __init__(
25
+ self,
26
+ vision_config=None,
27
+ llm_config=None,
28
+ m2f_config=None,
29
+ use_backbone_lora=0,
30
+ use_llm_lora=0,
31
+ pad2square=False,
32
+ select_layer=-1,
33
+ force_image_size=None,
34
+ downsample_ratio=0.5,
35
+ template=None,
36
+ dynamic_image_size=False,
37
+ use_thumbnail=False,
38
+ ps_version='v1',
39
+ min_dynamic_patch=1,
40
+ max_dynamic_patch=6,
41
+ # mask2former
42
+ num_m2f_queries=300,
43
+ num_m2f_proposals=100,
44
+ **kwargs):
45
+ super().__init__(**kwargs)
46
+ if vision_config is None:
47
+ vision_config = {"architectures": ["InternVisionModel"]}
48
+ logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
49
+
50
+ if llm_config is None:
51
+ llm_config = {'architectures': ['Qwen2ForCausalLM']}
52
+ logger.info('llm_config is None. Initializing the llm config with default values (`Qwen2Config`).')
53
+
54
+ if m2f_config is None:
55
+ m2f_config = {"architectures": ["SwinForImageClassification"]}
56
+ logger.info('m2f_config is None. Initializing the Mask2FormerConfig with default values.')
57
+
58
+ self.vision_config = InternVisionConfig(**vision_config)
59
+ self.m2f_config = Mask2FormerConfig(**m2f_config)
60
+
61
+ if llm_config['architectures'][0] == 'LlamaForCausalLM':
62
+ self.llm_config = LlamaConfig(**llm_config)
63
+ elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
64
+ self.llm_config = InternLM2Config(**llm_config)
65
+ elif llm_config['architectures'][0] == 'Phi3ForCausalLM':
66
+ self.llm_config = Phi3Config(**llm_config)
67
+ elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
68
+ self.llm_config = Qwen2Config(**llm_config)
69
+ else:
70
+ raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
71
+ self.use_backbone_lora = use_backbone_lora
72
+ self.use_llm_lora = use_llm_lora
73
+ self.pad2square = pad2square
74
+ self.select_layer = select_layer
75
+ self.force_image_size = force_image_size
76
+ self.downsample_ratio = downsample_ratio
77
+ self.template = template
78
+ self.dynamic_image_size = dynamic_image_size
79
+ self.use_thumbnail = use_thumbnail
80
+ self.ps_version = ps_version # pixel shuffle version
81
+ self.min_dynamic_patch = min_dynamic_patch
82
+ self.max_dynamic_patch = max_dynamic_patch
83
+ # mask2former
84
+ self.num_m2f_queries=num_m2f_queries
85
+ self.num_m2f_proposals=num_m2f_proposals
86
+
87
+ self.hidden_size = self.llm_config.hidden_size
88
+ self.tie_word_embeddings = False
89
+
90
+ logger.info(f'vision_select_layer: {self.select_layer}')
91
+ logger.info(f'ps_version: {self.ps_version}')
92
+ logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
93
+ logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
94
+
95
+ def to_dict(self):
96
+ """
97
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
98
+
99
+ Returns:
100
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
101
+ """
102
+ output = copy.deepcopy(self.__dict__)
103
+ output['vision_config'] = self.vision_config.to_dict()
104
+ output['llm_config'] = self.llm_config.to_dict()
105
+ output['m2f_config'] = self.m2f_config.to_dict()
106
+ output['model_type'] = self.__class__.model_type
107
+ output['use_backbone_lora'] = self.use_backbone_lora
108
+ output['use_llm_lora'] = self.use_llm_lora
109
+ output['pad2square'] = self.pad2square
110
+ output['select_layer'] = self.select_layer
111
+ output['force_image_size'] = self.force_image_size
112
+ output['downsample_ratio'] = self.downsample_ratio
113
+ output['template'] = self.template
114
+ output['dynamic_image_size'] = self.dynamic_image_size
115
+ output['use_thumbnail'] = self.use_thumbnail
116
+ output['ps_version'] = self.ps_version
117
+ output['min_dynamic_patch'] = self.min_dynamic_patch
118
+ output['max_dynamic_patch'] = self.max_dynamic_patch
119
+ output['num_m2f_queries'] = self.num_m2f_queries
120
+ output['num_m2f_proposals'] = self.num_m2f_proposals
121
+
122
+ return output
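
A minimal composition sketch (illustrative values, not the shipped config.json): the LLM sub-config class is picked from `llm_config['architectures'][0]`, and `to_dict()` re-serializes every sub-config.

```python
# Minimal sketch: build the composed config from plain dicts and round-trip it.
from configuration_sa2va_chat import Sa2VAChatConfig

cfg = Sa2VAChatConfig(
    vision_config={'architectures': ['InternVisionModel']},
    llm_config={'architectures': ['Qwen2ForCausalLM'], 'hidden_size': 3584},  # dispatches to Qwen2Config
    m2f_config={'architectures': ['SwinForImageClassification']},
)

d = cfg.to_dict()
assert d['model_type'] == 'sa2va_chat'
assert d['llm_config']['hidden_size'] == cfg.hidden_size == 3584
```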
constants.py ADDED
@@ -0,0 +1,13 @@
1
+ IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
2
+ IMG_START_TOKEN = '<img>'
3
+ IMG_END_TOKEN = '</img>'
4
+ PHRASE_START_TOKEN = '<p>'
5
+ PHRASE_END_TOKEN = '</p>'
6
+ SEG_TOKEN = '[SEG{id}]'
7
+ CLS_TOKEN = '[CLS]'
8
+ BG_CLS_TOKEN = '[BG_CLS]'
9
+ # PROPOSAL_TOKENS = [f'[SEG{str(i).zfill(3)}]' for i in range(300)]
10
+ OBJ_START_TOKEN = '<obj>'
11
+ OBJ_END_TOKEN = '</obj>'
12
+ OBJ_CONTEXT_TOKEN = '<OBJ_CONTEXT>'
13
+ DEFAULT_OBJ_TOKEN = '<obj_tokens>'
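
A short sketch of how the templated tokens expand; the phrase-tag pairing in the last line is illustrative only.

```python
# Minimal sketch: SEG_TOKEN uses a zero-padded id, matching the commented-out
# PROPOSAL_TOKENS example above.
from constants import SEG_TOKEN, PHRASE_START_TOKEN, PHRASE_END_TOKEN

seg_tokens = [SEG_TOKEN.format(id=str(i).zfill(3)) for i in range(3)]
print(seg_tokens)  # ['[SEG000]', '[SEG001]', '[SEG002]']
print(f"{PHRASE_START_TOKEN}a red car{PHRASE_END_TOKEN} {seg_tokens[0]}")  # <p>a red car</p> [SEG000]
```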
flash_attention.py ADDED
@@ -0,0 +1,76 @@
1
+ # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+
6
+ try: # v1
7
+ from flash_attn.flash_attn_interface import \
8
+ flash_attn_unpadded_qkvpacked_func
9
+ except ImportError: # v2
10
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
+
12
+ from flash_attn.bert_padding import pad_input, unpad_input
13
+
14
+
15
+ class FlashAttention(nn.Module):
16
+ """Implement the scaled dot product attention with softmax.
17
+ Arguments
18
+ ---------
19
+ softmax_scale: The temperature to use for the softmax attention.
20
+ (default: 1/sqrt(d_keys) where d_keys is computed at
21
+ runtime)
22
+ attention_dropout: The dropout rate to apply to the attention
23
+ (default: 0.0)
24
+ """
25
+
26
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
+ super().__init__()
28
+ self.softmax_scale = softmax_scale
29
+ self.dropout_p = attention_dropout
30
+
31
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
+ max_s=None, need_weights=False):
33
+ """Implements the multihead softmax attention.
34
+ Arguments
35
+ ---------
36
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
+ if unpadded: (nnz, 3, h, d)
38
+ key_padding_mask: a bool tensor of shape (B, S)
39
+ """
40
+ assert not need_weights
41
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
42
+ assert qkv.is_cuda
43
+
44
+ if cu_seqlens is None:
45
+ batch_size = qkv.shape[0]
46
+ seqlen = qkv.shape[1]
47
+ if key_padding_mask is None:
48
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
+ max_s = seqlen
50
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
+ device=qkv.device)
52
+ output = flash_attn_unpadded_qkvpacked_func(
53
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
+ softmax_scale=self.softmax_scale, causal=causal
55
+ )
56
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
+ else:
58
+ nheads = qkv.shape[-2]
59
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
63
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
+ softmax_scale=self.softmax_scale, causal=causal
65
+ )
66
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
+ indices, batch_size, seqlen),
68
+ 'b s (h d) -> b s h d', h=nheads)
69
+ else:
70
+ assert max_s is not None
71
+ output = flash_attn_unpadded_qkvpacked_func(
72
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
+ softmax_scale=self.softmax_scale, causal=causal
74
+ )
75
+
76
+ return output, None
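
A usage sketch for the padded, fixed-length path (`key_padding_mask=None`, `cu_seqlens=None`); it assumes a CUDA device and the `flash-attn` package, and uses illustrative tensor sizes.

```python
# Minimal sketch: packed qkv of shape (batch, seqlen, 3, heads, head_dim) in
# half precision on GPU, which is what the asserts in forward() require.
import torch
from flash_attention import FlashAttention

attn = FlashAttention(attention_dropout=0.0)
qkv = torch.randn(2, 128, 3, 16, 64, dtype=torch.float16, device='cuda')
out, _ = attn(qkv, causal=False)   # second return value is always None
print(out.shape)                   # torch.Size([2, 128, 16, 64])
```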
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.47.0"
4
+ }
mask2former.py ADDED
@@ -0,0 +1,834 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+ from transformers.models.mask2former.modeling_mask2former import (
6
+ Mask2FormerMaskedAttentionDecoderOutput, Mask2FormerModelOutput,
7
+ Mask2FormerForUniversalSegmentationOutput, Mask2FormerMLPPredictionHead,
8
+ sample_point, pair_wise_sigmoid_cross_entropy_loss, pair_wise_dice_loss,
9
+ sigmoid_cross_entropy_loss, dice_loss)
10
+ from torch import Tensor
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.file_utils import is_scipy_available
14
+
15
+ if is_scipy_available():
16
+ from scipy.optimize import linear_sum_assignment
17
+
18
+
19
+ def get_classification_logits(x, text_classifier, logit_scale):
20
+ # x in shape of [B, *, C]
21
+ # text_classifier in shape of [num_classes, C]
22
+ # logit_scale is a learnable scalar https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/model.py#L201
23
+ # return: [B, *, num_classes]
24
+ x = F.normalize(x, dim=-1)
25
+ text_classifier = F.normalize(text_classifier, dim=-1)
26
+ logit_scale = torch.clamp(logit_scale.exp(), max=100)
27
+ pred_logits = logit_scale * x @ text_classifier.T # B, *, N + 1
28
+ return pred_logits
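
A shape-only sketch of the call above (illustrative sizes; the scale value mirrors the `np.log(1 / 0.07)` initialization used in `_post_init` below):

```python
# Minimal sketch: cosine-similarity logits against a text classifier, scaled by
# a clamped, learnable temperature.
import torch
from mask2former import get_classification_logits  # this file

x = torch.randn(2, 100, 256)               # (batch, num_queries, channels)
text_classifier = torch.randn(151, 256)    # (num_classes + background, channels)
logit_scale = torch.tensor(2.6593)         # ~log(1 / 0.07)
logits = get_classification_logits(x, text_classifier, logit_scale)
print(logits.shape)                        # torch.Size([2, 100, 151])
```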
29
+
30
+
31
+ def _post_init(self):
32
+ self.class_embed = Mask2FormerMLPPredictionHead(self.config.hidden_dim, self.config.hidden_dim, self.config.hidden_dim, 3)
33
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
34
+
35
+
36
+ def ov_class_predictor(self, x, text_classifier):
37
+ x = self.class_embed(x)
38
+ all_pred_logits = []
39
+ for per_x, per_text_classifier in zip(x, text_classifier):
40
+ per_pred_logits = get_classification_logits(per_x.unsqueeze(0), per_text_classifier, self.logit_scale)
41
+ all_pred_logits.append(per_pred_logits.squeeze(0))
42
+
43
+ return all_pred_logits
44
+
45
+
46
+
47
+ def Mask2FormerLoss_loss_labels(
48
+ self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
49
+ ) -> Dict[str, Tensor]:
50
+ batch_size = len(class_queries_logits)
51
+ num_queries = class_queries_logits[0].shape[0]
52
+ all_ce_loss = []
53
+ for i in range(batch_size):
54
+ num_labels_plus1 = class_queries_logits[i].shape[-1]
55
+ empty_weight = torch.ones(num_labels_plus1)
56
+ empty_weight[-1] = self.eos_coef
57
+ empty_weight = empty_weight.to(class_queries_logits[i].device).to(class_queries_logits[i].dtype)
58
+ criterion = nn.CrossEntropyLoss(weight=empty_weight, reduction='none')
59
+ target_classes_o = class_labels[i][indices[i][1]]
60
+ target_classes = torch.full(
61
+ (num_queries, ), fill_value=num_labels_plus1-1, dtype=torch.int64, device=class_queries_logits[i].device)
62
+ target_classes[indices[i][0]] = target_classes_o.to(class_queries_logits[i].device)
63
+ target_classes = target_classes.unsqueeze(0)
64
+ pred_logits = class_queries_logits[i].unsqueeze(0).transpose(1, 2)
65
+ loss_ce = criterion(pred_logits, target_classes)
66
+ all_ce_loss.append(loss_ce)
67
+ losses = {"loss_cross_entropy": torch.cat(all_ce_loss, dim=-1).mean()}
68
+ return losses
69
+
70
+ def Mask2FormerLoss_loss_masks(
71
+ self,
72
+ masks_queries_logits: torch.Tensor,
73
+ mask_labels: List[torch.Tensor],
74
+ indices: Tuple[np.array],
75
+ num_masks: int
76
+ ) -> Dict[str, torch.Tensor]:
77
+ src_idx = self._get_predictions_permutation_indices(indices)
78
+ tgt_idx = self._get_targets_permutation_indices(indices)
79
+ # shape (batch_size * num_queries, height, width)
80
+ pred_masks = masks_queries_logits[src_idx]
81
+ # shape (batch_size, num_queries, height, width)
82
+ # pad all and stack the targets to the num_labels dimension
83
+ target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
84
+ target_masks = target_masks[tgt_idx]
85
+
86
+ # No need to upsample predictions as we are using normalized coordinates
87
+ pred_masks = pred_masks[:, None]
88
+ target_masks = target_masks[:, None]
89
+
90
+ # Sample point coordinates
91
+ with torch.no_grad():
92
+ point_coordinates = self.sample_points_using_uncertainty(
93
+ pred_masks,
94
+ lambda logits: self.calculate_uncertainty(logits),
95
+ self.num_points,
96
+ self.oversample_ratio,
97
+ self.importance_sample_ratio,
98
+ )
99
+ point_labels = sample_point(target_masks.to(torch.bfloat16), point_coordinates.to(torch.bfloat16), align_corners=False).squeeze(1)
100
+
101
+ point_logits = sample_point(pred_masks, point_coordinates.to(pred_masks.dtype), align_corners=False).squeeze(1)
102
+
103
+ losses = {
104
+ "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
105
+ "loss_dice": dice_loss(point_logits, point_labels, num_masks),
106
+ }
107
+
108
+ del pred_masks
109
+ del target_masks
110
+ return losses
111
+
112
+ def Mask2FormerLoss_sample_points_using_uncertainty(
113
+ self,
114
+ logits: torch.Tensor,
115
+ uncertainty_function,
116
+ num_points: int,
117
+ oversample_ratio: int,
118
+ importance_sample_ratio: float,
119
+ ) -> torch.Tensor:
120
+
121
+ num_boxes = logits.shape[0]
122
+ num_points_sampled = int(num_points * oversample_ratio)
123
+
124
+ # Get random point coordinates
125
+ point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
126
+ # Get sampled prediction value for the point coordinates
127
+ point_logits = sample_point(logits, point_coordinates.to(logits.dtype), align_corners=False)
128
+ # Calculate the uncertainties based on the sampled prediction values of the points
129
+ point_uncertainties = uncertainty_function(point_logits)
130
+
131
+ num_uncertain_points = int(importance_sample_ratio * num_points)
132
+ num_random_points = num_points - num_uncertain_points
133
+
134
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
135
+ shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
136
+ idx += shift[:, None]
137
+ point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
138
+
139
+ if num_random_points > 0:
140
+ point_coordinates = torch.cat(
141
+ [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
142
+ dim=1,
143
+ )
144
+ return point_coordinates
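
A small arithmetic sketch of the bookkeeping above, assuming values commonly used with Mask2Former (`num_points=12544`, `oversample_ratio=3.0`, `importance_sample_ratio=0.75`):

```python
# Minimal sketch: how many candidate, uncertainty-ranked and random points the
# sampler ends up with for the assumed hyper-parameters.
num_points, oversample_ratio, importance_sample_ratio = 12544, 3.0, 0.75

num_points_sampled = int(num_points * oversample_ratio)           # 37632 candidates
num_uncertain_points = int(importance_sample_ratio * num_points)  # 9408 kept by top-k uncertainty
num_random_points = num_points - num_uncertain_points             # 3136 re-drawn uniformly
print(num_points_sampled, num_uncertain_points, num_random_points)
```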
145
+
146
+
147
+
148
+ @torch.no_grad()
149
+ def Mask2FormerHungarianMatcher_forward(
150
+ self,
151
+ masks_queries_logits: torch.Tensor,
152
+ class_queries_logits: torch.Tensor,
153
+ mask_labels: torch.Tensor,
154
+ class_labels: torch.Tensor,
155
+ ) -> List[Tuple[Tensor]]:
156
+ indices: List[Tuple[np.array]] = []
157
+
158
+ # iterate through batch size
159
+ batch_size = masks_queries_logits.shape[0]
160
+ for i in range(batch_size):
161
+ pred_probs = class_queries_logits[i].softmax(-1)
162
+ pred_mask = masks_queries_logits[i]
163
+
164
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it by 1 - proba[target class]. The 1 is a constant that doesn't change the matching, so it can be omitted.
165
+ cost_class = -pred_probs[:, class_labels[i]]
166
+ target_mask = mask_labels[i].to(pred_mask)
167
+ target_mask = target_mask[:, None]
168
+ pred_mask = pred_mask[:, None]
169
+
170
+ # Sample ground truth and predicted masks
171
+ point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
172
+
173
+ target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1).to(target_mask.dtype)
174
+ target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
175
+
176
+ pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1).to(pred_mask.dtype)
177
+ pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
178
+
179
+ # compute the cross entropy loss between each pair of masks -> shape (num_queries, num_labels)
180
+ cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
181
+ # Compute the dice loss between each pair of masks -> shape (num_queries, num_labels)
182
+ cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
183
+ # final cost matrix
184
+ cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
185
+ # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
186
+ cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
187
+ cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
188
+ cost_matrix = torch.nan_to_num(cost_matrix, 0)
189
+ # do the assignment using the Hungarian algorithm in scipy
190
+ assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.to(torch.float32).cpu())
191
+ indices.append(assigned_indices)
192
+
193
+ # It could be stacked in one tensor
194
+ matched_indices = [
195
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
196
+ ]
197
+ return matched_indices
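
A tiny sketch of the assignment step at the end of the matcher (toy cost matrix; rows are predicted queries, columns are ground-truth masks):

```python
# Minimal sketch: scipy's linear_sum_assignment returns the minimum-cost
# one-to-one matching used to pair queries with targets.
import numpy as np
from scipy.optimize import linear_sum_assignment

cost_matrix = np.array([[0.9, 0.1, 0.5],
                        [0.2, 0.8, 0.4]])          # (num_queries=2, num_targets=3)
pred_idx, target_idx = linear_sum_assignment(cost_matrix)
print(pred_idx, target_idx)                        # [0 1] [1 0]: query 0 <-> gt 1, query 1 <-> gt 0
```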
198
+
199
+
200
+
201
+
202
+ def Mask2FormerMaskedAttentionDecoder_forward_first3layers(
203
+ self,
204
+ inputs_embeds: torch.Tensor = None,
205
+ multi_stage_positional_embeddings: torch.Tensor = None,
206
+ pixel_embeddings: torch.Tensor = None,
207
+ encoder_hidden_states: torch.Tensor = None,
208
+ query_position_embeddings: torch.Tensor = None,
209
+ feature_size_list: List = None,
210
+ output_attentions: Optional[bool] = None,
211
+ output_hidden_states: Optional[bool] = None,
212
+ return_dict: Optional[bool] = None,
213
+ ):
214
+ r"""
215
+ Args:
216
+ inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
217
+ The query embeddings that are passed into the decoder.
218
+ multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
219
+ Position embeddings that are added to the keys in each cross(masked)-attention layer.
220
+ pixel_embeddings (`torch.FloatTensor`):
221
+ Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
222
+ Decoder.
223
+ query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
224
+ Position embeddings that are added to the queries and keys in each self-attention layer.
225
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
226
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the
227
+ cross(masked)-attention of the decoder.
228
+ feature_size_list (`List[torch.Size]`):
229
+ This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
230
+ output_attentions (`bool`, *optional*):
231
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
232
+ returned tensors for more detail.
233
+ output_hidden_states (`bool`, *optional*):
234
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
235
+ for more detail.
236
+ return_dict (`bool`, *optional*):
237
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
238
+ """
239
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
240
+ output_hidden_states = (
241
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
242
+ )
243
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
244
+
245
+ if inputs_embeds is not None:
246
+ hidden_states = inputs_embeds
247
+
248
+ # intermediate hidden states with layernorm applied - required for predicting class logits
249
+ intermediate = ()
250
+
251
+ # decoder layers
252
+ all_hidden_states = () if output_hidden_states else None
253
+ attentions = () if output_attentions else None
254
+
255
+ # intermediate mask predictions from transformer decoder layers
256
+ intermediate_mask_predictions = ()
257
+
258
+ intermediate_hidden_states = self.layernorm(inputs_embeds)
259
+ intermediate += (intermediate_hidden_states,)
260
+
261
+ predicted_mask, attention_mask = self.mask_predictor(
262
+ intermediate_hidden_states, pixel_embeddings, feature_size_list[0]
263
+ )
264
+ intermediate_mask_predictions += (predicted_mask,)
265
+
266
+ for idx, decoder_layer in enumerate(self.layers[:3]):
267
+ if output_hidden_states:
268
+ all_hidden_states += (hidden_states,)
269
+
270
+ dropout_probability = torch.rand([])
271
+
272
+ if self.training and (dropout_probability < self.layerdrop):
273
+ continue
274
+
275
+ if self.gradient_checkpointing and self.training:
276
+ layer_outputs = self._gradient_checkpointing_func(
277
+ decoder_layer.__call__,
278
+ hidden_states,
279
+ attention_mask,
280
+ encoder_hidden_states,
281
+ None,
282
+ None,
283
+ output_attentions,
284
+ )
285
+
286
+ else:
287
+ level_index = idx % self.num_feature_levels
288
+
289
+ where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
290
+ # Multiply the attention mask instead of indexing to avoid issue in torch.export.
291
+ attention_mask = attention_mask * where.unsqueeze(-1)
292
+
293
+ layer_outputs = decoder_layer(
294
+ hidden_states,
295
+ level_index=level_index,
296
+ position_embeddings=multi_stage_positional_embeddings,
297
+ query_position_embeddings=query_position_embeddings,
298
+ encoder_hidden_states=encoder_hidden_states,
299
+ encoder_attention_mask=attention_mask,
300
+ output_attentions=output_attentions,
301
+ )
302
+
303
+ intermediate_hidden_states = self.layernorm(layer_outputs[0])
304
+
305
+ predicted_mask, attention_mask = self.mask_predictor(
306
+ intermediate_hidden_states,
307
+ pixel_embeddings,
308
+ feature_size_list[(idx + 1) % self.num_feature_levels],
309
+ )
310
+
311
+ intermediate_mask_predictions += (predicted_mask,)
312
+
313
+ # add intermediate hidden states with layer norm applied which will be used for predicting class logits
314
+ intermediate += (intermediate_hidden_states,)
315
+
316
+ hidden_states = layer_outputs[0]
317
+
318
+ if output_attentions:
319
+ attentions += (layer_outputs[1],)
320
+
321
+ # add hidden states from the last decoder layer
322
+ if output_hidden_states:
323
+ all_hidden_states += (hidden_states,)
324
+
325
+ hidden_states = hidden_states.transpose(1, 0)
326
+ if not return_dict:
327
+ outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
328
+ return tuple(v for v in outputs if v is not None)
329
+
330
+ return Mask2FormerMaskedAttentionDecoderOutput(
331
+ last_hidden_state=hidden_states,
332
+ hidden_states=all_hidden_states,
333
+ attentions=attentions,
334
+ intermediate_hidden_states=intermediate,
335
+ masks_queries_logits=intermediate_mask_predictions,
336
+ )
337
+
338
+
339
+ def Mask2FormerMaskedAttentionDecoder_forward_last3layers(
340
+ self,
341
+ inputs_embeds: torch.Tensor = None,
342
+ multi_stage_positional_embeddings: torch.Tensor = None,
343
+ pixel_embeddings: torch.Tensor = None,
344
+ encoder_hidden_states: torch.Tensor = None,
345
+ query_position_embeddings: torch.Tensor = None,
346
+ feature_size_list: List = None,
347
+ output_attentions: Optional[bool] = None,
348
+ output_hidden_states: Optional[bool] = None,
349
+ return_dict: Optional[bool] = None,
350
+ ):
351
+ r"""
352
+ Args:
353
+ inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
354
+ The query embeddings that are passed into the decoder.
355
+ multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
356
+ Position embeddings that are added to the keys in each cross(masked)-attention layer.
357
+ pixel_embeddings (`torch.FloatTensor`):
358
+ Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
359
+ Decoder.
360
+ query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
361
+ Position embeddings that are added to the queries and keys in each self-attention layer.
362
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
363
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the
364
+ cross(masked)-attention of the decoder.
365
+ feature_size_list (`List[torch.Size]`):
366
+ This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
367
+ output_attentions (`bool`, *optional*):
368
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
369
+ returned tensors for more detail.
370
+ output_hidden_states (`bool`, *optional*):
371
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
372
+ for more detail.
373
+ return_dict (`bool`, *optional*):
374
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
375
+ """
376
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
377
+ output_hidden_states = (
378
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
379
+ )
380
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
381
+
382
+ if inputs_embeds is not None:
383
+ hidden_states = inputs_embeds
384
+
385
+ # intermediate hidden states with layernorm applied - required for predicting class logits
386
+ intermediate = ()
387
+
388
+ # decoder layers
389
+ all_hidden_states = () if output_hidden_states else None
390
+ attentions = () if output_attentions else None
391
+
392
+ # intermediate mask predictions from transformer decoder layers
393
+ intermediate_mask_predictions = ()
394
+
395
+ intermediate_hidden_states = self.layernorm(inputs_embeds)
396
+ intermediate += (intermediate_hidden_states,)
397
+
398
+ predicted_mask, attention_mask = self.mask_predictor(
399
+ intermediate_hidden_states, pixel_embeddings, feature_size_list[0]
400
+ )
401
+ intermediate_mask_predictions += (predicted_mask,)
402
+
403
+ for _idx, decoder_layer in enumerate(self.layers[3:]):
404
+ idx = _idx + 3
405
+ if output_hidden_states:
406
+ all_hidden_states += (hidden_states,)
407
+
408
+ dropout_probability = torch.rand([])
409
+
410
+ if self.training and (dropout_probability < self.layerdrop):
411
+ continue
412
+
413
+ if self.gradient_checkpointing and self.training:
414
+ layer_outputs = self._gradient_checkpointing_func(
415
+ decoder_layer.__call__,
416
+ hidden_states,
417
+ attention_mask,
418
+ encoder_hidden_states,
419
+ None,
420
+ None,
421
+ output_attentions,
422
+ )
423
+
424
+ else:
425
+ level_index = idx % self.num_feature_levels
426
+
427
+ where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
428
+ # Multiply the attention mask instead of indexing to avoid issue in torch.export.
429
+ attention_mask = attention_mask * where.unsqueeze(-1)
430
+
431
+ layer_outputs = decoder_layer(
432
+ hidden_states,
433
+ level_index=level_index,
434
+ position_embeddings=multi_stage_positional_embeddings,
435
+ query_position_embeddings=query_position_embeddings,
436
+ encoder_hidden_states=encoder_hidden_states,
437
+ encoder_attention_mask=attention_mask,
438
+ output_attentions=output_attentions,
439
+ )
440
+
441
+ intermediate_hidden_states = self.layernorm(layer_outputs[0])
442
+
443
+ predicted_mask, attention_mask = self.mask_predictor(
444
+ intermediate_hidden_states,
445
+ pixel_embeddings,
446
+ feature_size_list[(idx + 1) % self.num_feature_levels],
447
+ )
448
+
449
+ intermediate_mask_predictions += (predicted_mask,)
450
+
451
+ # add intermediate hidden states with layer norm applied which will be used for predicting class logits
452
+ intermediate += (intermediate_hidden_states,)
453
+
454
+ hidden_states = layer_outputs[0]
455
+
456
+ if output_attentions:
457
+ attentions += (layer_outputs[1],)
458
+
459
+ # add hidden states from the last decoder layer
460
+ if output_hidden_states:
461
+ all_hidden_states += (hidden_states,)
462
+
463
+ hidden_states = hidden_states.transpose(1, 0)
464
+ if not return_dict:
465
+ outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
466
+ return tuple(v for v in outputs if v is not None)
467
+
468
+ return Mask2FormerMaskedAttentionDecoderOutput(
469
+ last_hidden_state=hidden_states,
470
+ hidden_states=all_hidden_states,
471
+ attentions=attentions,
472
+ intermediate_hidden_states=intermediate,
473
+ masks_queries_logits=intermediate_mask_predictions,
474
+ )
475
+
476
+
477
+ def Mask2FormerTransformerModule_forward_first_part(
478
+ self,
479
+ multi_scale_features: List[Tensor],
480
+ mask_features: Tensor,
481
+ output_hidden_states: bool = False,
482
+ output_attentions: bool = False,
483
+ ) -> Mask2FormerMaskedAttentionDecoderOutput:
484
+ multi_stage_features = []
485
+ multi_stage_positional_embeddings = []
486
+ size_list = []
487
+
488
+ for i in range(self.num_feature_levels):
489
+ size_list.append(multi_scale_features[i].shape[-2:])
490
+ multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
491
+ multi_stage_features.append(
492
+ self.input_projections[i](multi_scale_features[i]).flatten(2)
493
+ + self.level_embed.weight[i][None, :, None]
494
+ )
495
+
496
+ # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels)
497
+ multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
498
+ multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)
499
+
500
+ _, batch_size, _ = multi_stage_features[0].shape
501
+
502
+ # [num_queries, batch_size, num_channels]
503
+ query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
504
+ query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)
505
+
506
+ decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers(
507
+ inputs_embeds=query_features,
508
+ multi_stage_positional_embeddings=multi_stage_positional_embeddings,
509
+ pixel_embeddings=mask_features,
510
+ encoder_hidden_states=multi_stage_features,
511
+ query_position_embeddings=query_embeddings,
512
+ feature_size_list=size_list,
513
+ output_hidden_states=output_hidden_states,
514
+ output_attentions=output_attentions,
515
+ return_dict=True,
516
+ )
517
+
518
+ return decoder_output
519
+
520
+
521
+ def Mask2FormerTransformerModule_forward_second_part(
522
+ self,
523
+ query_features: Tensor,
524
+ query_embeddings: Tensor,
525
+ multi_scale_features: List[Tensor],
526
+ mask_features: Tensor,
527
+ output_hidden_states: bool = False,
528
+ output_attentions: bool = False,
529
+ ) -> Mask2FormerMaskedAttentionDecoderOutput:
530
+ multi_stage_features = []
531
+ multi_stage_positional_embeddings = []
532
+ size_list = []
533
+
534
+ for i in range(self.num_feature_levels):
535
+ size_list.append(multi_scale_features[i].shape[-2:])
536
+ multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
537
+ multi_stage_features.append(
538
+ self.input_projections[i](multi_scale_features[i]).flatten(2)
539
+ + self.level_embed.weight[i][None, :, None]
540
+ )
541
+
542
+ # Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels)
543
+ multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
544
+ multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)
545
+
546
+ _, batch_size, _ = multi_stage_features[0].shape
547
+
548
+ # [num_queries, batch_size, num_channels]
549
+ # query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
550
+ # query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)
551
+
552
+ decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers(
553
+ inputs_embeds=query_features,
554
+ multi_stage_positional_embeddings=multi_stage_positional_embeddings,
555
+ pixel_embeddings=mask_features,
556
+ encoder_hidden_states=multi_stage_features,
557
+ query_position_embeddings=query_embeddings,
558
+ feature_size_list=size_list,
559
+ output_hidden_states=output_hidden_states,
560
+ output_attentions=output_attentions,
561
+ return_dict=True,
562
+ )
563
+
564
+ return decoder_output
565
+
566
+
567
+ def Mask2FormerModel_forward_first_part(
568
+ self,
569
+ pixel_values: Tensor,
570
+ pixel_mask: Optional[Tensor] = None,
571
+ output_hidden_states: Optional[bool] = None,
572
+ output_attentions: Optional[bool] = None,
573
+ return_dict: Optional[bool] = None,
574
+ ) -> Mask2FormerModelOutput:
575
+ r"""
576
+ Returns:
577
+ `Mask2FormerModelOutput`
578
+
579
+ Examples:
580
+ ```python
581
+ >>> import torch
582
+ >>> from PIL import Image
583
+ >>> import requests
584
+ >>> from transformers import AutoImageProcessor, Mask2FormerModel
585
+
586
+ >>> # load image
587
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
588
+ >>> image = Image.open(requests.get(url, stream=True).raw)
589
+
590
+ >>> # load image preprocessor and Mask2FormerModel trained on COCO instance segmentation dataset
591
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
592
+ >>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-coco-instance")
593
+ >>> inputs = image_processor(image, return_tensors="pt")
594
+
595
+ >>> # forward pass
596
+ >>> with torch.no_grad():
597
+ ... outputs = model(**inputs)
598
+
599
+ >>> # model outputs last hidden states of shape (batch_size, num_queries, hidden_size)
600
+ >>> print(outputs.transformer_decoder_last_hidden_state.shape)
601
+ torch.Size([1, 100, 256])
602
+ ```
603
+ """
604
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
605
+ output_hidden_states = (
606
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
607
+ )
608
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
609
+
610
+ batch_size, _, height, width = pixel_values.shape
611
+
612
+ if pixel_mask is None:
613
+ pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
614
+
615
+ pixel_level_module_output = self.pixel_level_module(
616
+ pixel_values=pixel_values, output_hidden_states=output_hidden_states
617
+ )
618
+
619
+ transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_first_part(
620
+ multi_scale_features=pixel_level_module_output.decoder_hidden_states,
621
+ mask_features=pixel_level_module_output.decoder_last_hidden_state,
622
+ output_hidden_states=True,
623
+ output_attentions=output_attentions,
624
+ )
625
+
626
+ query_features = transformer_module_output.last_hidden_state
627
+ return query_features, pixel_level_module_output
628
+
629
+
630
+ def Mask2FormerModel_forward_second_part(
631
+ self,
632
+ query_features: Tensor,
633
+ query_embeddings: Tensor,
634
+ pixel_level_module_output,
635
+ pixel_values: Tensor,
636
+ pixel_mask: Optional[Tensor] = None,
637
+ output_hidden_states: Optional[bool] = None,
638
+ output_attentions: Optional[bool] = None,
639
+ return_dict: Optional[bool] = None,
640
+ ) -> Mask2FormerModelOutput:
641
+ r"""
642
+ Returns:
643
+ `Mask2FormerModelOutput`
644
+
645
+ Examples:
646
+ ```python
647
+ >>> import torch
648
+ >>> from PIL import Image
649
+ >>> import requests
650
+ >>> from transformers import AutoImageProcessor, Mask2FormerModel
651
+
652
+ >>> # load image
653
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
654
+ >>> image = Image.open(requests.get(url, stream=True).raw)
655
+
656
+ >>> # load image preprocessor and Mask2FormerModel trained on COCO instance segmentation dataset
657
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-coco-instance")
658
+ >>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-coco-instance")
659
+ >>> inputs = image_processor(image, return_tensors="pt")
660
+
661
+ >>> # forward pass
662
+ >>> with torch.no_grad():
663
+ ... outputs = model(**inputs)
664
+
665
+ >>> # model outputs last hidden states of shape (batch_size, num_queries, hidden_size)
666
+ >>> print(outputs.transformer_decoder_last_hidden_state.shape)
667
+ torch.Size([1, 100, 256])
668
+ ```
669
+ """
670
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
671
+ output_hidden_states = (
672
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
673
+ )
674
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
675
+
676
+ batch_size, _, height, width = pixel_values.shape
677
+
678
+ if pixel_mask is None:
679
+ pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
680
+
681
+ transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_second_part(
682
+ query_features=query_features,
683
+ query_embeddings=query_embeddings,
684
+ multi_scale_features=pixel_level_module_output.decoder_hidden_states,
685
+ mask_features=pixel_level_module_output.decoder_last_hidden_state,
686
+ output_hidden_states=True,
687
+ output_attentions=output_attentions,
688
+ )
689
+
690
+ encoder_hidden_states = None
691
+ pixel_decoder_hidden_states = None
692
+ transformer_decoder_hidden_states = None
693
+ transformer_decoder_intermediate_states = None
694
+
695
+ if output_hidden_states:
696
+ encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
697
+ pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
698
+ transformer_decoder_hidden_states = transformer_module_output.hidden_states
699
+ transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states
700
+
701
+ output = Mask2FormerModelOutput(
702
+ encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state,
703
+ pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state,
704
+ transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state,
705
+ encoder_hidden_states=encoder_hidden_states,
706
+ pixel_decoder_hidden_states=pixel_decoder_hidden_states,
707
+ transformer_decoder_hidden_states=transformer_decoder_hidden_states,
708
+ transformer_decoder_intermediate_states=transformer_decoder_intermediate_states,
709
+ attentions=transformer_module_output.attentions,
710
+ masks_queries_logits=transformer_module_output.masks_queries_logits,
711
+ )
712
+
713
+ if not return_dict:
714
+ output = tuple(v for v in output.values() if v is not None)
715
+
716
+ return output
717
+
718
+
719
+ def Mask2FormerForUniversalSegmentation_forward_first_part(
720
+ self,
721
+ pixel_values: Tensor,
722
+ mask_labels: Optional[List[Tensor]] = None,
723
+ class_labels: Optional[List[Tensor]] = None,
724
+ pixel_mask: Optional[Tensor] = None,
725
+ output_hidden_states: Optional[bool] = None,
726
+ output_auxiliary_logits: Optional[bool] = None,
727
+ output_attentions: Optional[bool] = None,
728
+ return_dict: Optional[bool] = None,
729
+ ) -> Mask2FormerForUniversalSegmentationOutput:
730
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
731
+ output_hidden_states = (
732
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
733
+ )
734
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
735
+
736
+ query_features, pixel_level_module_output = self.model.Mask2FormerModel_forward_first_part(
737
+ pixel_values=pixel_values,
738
+ pixel_mask=pixel_mask,
739
+ output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
740
+ output_attentions=output_attentions,
741
+ return_dict=True,
742
+ )
743
+
744
+ return query_features, pixel_level_module_output
745
+
746
+
747
+ def Mask2FormerForUniversalSegmentation_forward_second_part(
748
+ self,
749
+ query_features,
750
+ query_embeddings,
751
+ pixel_level_module_output,
752
+ text_classifier,
753
+ pixel_values: Tensor,
754
+ mask_labels: Optional[List[Tensor]] = None,
755
+ class_labels: Optional[List[Tensor]] = None,
756
+ pixel_mask: Optional[Tensor] = None,
757
+ output_hidden_states: Optional[bool] = None,
758
+ output_auxiliary_logits: Optional[bool] = None,
759
+ output_attentions: Optional[bool] = None,
760
+ return_dict: Optional[bool] = None,
761
+ ) -> Mask2FormerForUniversalSegmentationOutput:
762
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
763
+ output_hidden_states = (
764
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
765
+ )
766
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
767
+
768
+ outputs = self.model.Mask2FormerModel_forward_second_part(
769
+ query_features=query_features,
770
+ query_embeddings=query_embeddings,
771
+ pixel_level_module_output=pixel_level_module_output,
772
+ pixel_values=pixel_values,
773
+ pixel_mask=pixel_mask,
774
+ output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
775
+ output_attentions=output_attentions,
776
+ return_dict=True,
777
+ )
778
+
779
+ loss, loss_dict, auxiliary_logits = None, None, None
780
+ class_queries_logits = ()
781
+
782
+ for decoder_output in outputs.transformer_decoder_intermediate_states:
783
+ class_prediction = self.ov_class_predictor(decoder_output.transpose(0, 1), text_classifier)
784
+ # class_prediction = self.class_predictor(decoder_output.transpose(0, 1))
785
+ class_queries_logits += (class_prediction,)
786
+
787
+ masks_queries_logits = outputs.masks_queries_logits
788
+
789
+ auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits)
790
+
791
+ if mask_labels is not None and class_labels is not None:
792
+ loss_dict = self.get_loss_dict(
793
+ masks_queries_logits=masks_queries_logits[-1],
794
+ class_queries_logits=class_queries_logits[-1],
795
+ mask_labels=mask_labels,
796
+ class_labels=class_labels,
797
+ auxiliary_predictions=auxiliary_logits,
798
+ )
799
+ loss = self.get_loss(loss_dict)
800
+
801
+ encoder_hidden_states = None
802
+ pixel_decoder_hidden_states = None
803
+ transformer_decoder_hidden_states = None
804
+
805
+ if output_hidden_states:
806
+ encoder_hidden_states = outputs.encoder_hidden_states
807
+ pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states
808
+ transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states
809
+
810
+ output_auxiliary_logits = (
811
+ self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
812
+ )
813
+ if not output_auxiliary_logits:
814
+ auxiliary_logits = None
815
+
816
+ output = Mask2FormerForUniversalSegmentationOutput(
817
+ loss=loss,
818
+ class_queries_logits=class_queries_logits[-1],
819
+ masks_queries_logits=masks_queries_logits[-1],
820
+ auxiliary_logits=auxiliary_logits,
821
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
822
+ pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state,
823
+ transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state,
824
+ encoder_hidden_states=encoder_hidden_states,
825
+ pixel_decoder_hidden_states=pixel_decoder_hidden_states,
826
+ transformer_decoder_hidden_states=transformer_decoder_hidden_states,
827
+ attentions=outputs.attentions,
828
+ )
829
+
830
+ if not return_dict:
831
+ output = tuple(v for v in output.values() if v is not None)
832
+ if loss is not None:
833
+ output = (loss,) + output
834
+ return output
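The split forward above exists so that externally produced `query_embeddings` and a `text_classifier` can be injected between the encoder stage and the mask decoder stage, with `ov_class_predictor` scoring each decoder query against the text embeddings instead of a fixed classification head. A minimal sketch of what such an open-vocabulary scoring step typically computes; the normalization and `logit_scale` below are illustrative assumptions, not necessarily this repository's exact implementation:

```python
import torch
import torch.nn.functional as F

def ov_classify(decoder_output: torch.Tensor, text_classifier: torch.Tensor, logit_scale: float = 100.0):
    """Score each query embedding against a set of text embeddings (scaled cosine similarity)."""
    q = F.normalize(decoder_output, dim=-1)        # (batch, num_queries, hidden)
    t = F.normalize(text_classifier, dim=-1)       # (num_classes, hidden)
    return logit_scale * q @ t.t()                 # (batch, num_queries, num_classes)

queries = torch.randn(1, 100, 256)                 # 100 decoder queries, hidden size 256
text_embeds = torch.randn(20, 256)                 # 20 open-vocabulary class prompts
print(ov_classify(queries, text_embeds).shape)     # torch.Size([1, 100, 20])
```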
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_intern_vit.py ADDED
@@ -0,0 +1,364 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from einops import rearrange
13
+ from timm.models.layers import DropPath
14
+ from torch import nn
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import (BaseModelOutput,
17
+ BaseModelOutputWithPooling)
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging
20
+
21
+ from .configuration_intern_vit import InternVisionConfig
22
+
23
+ try:
24
+ from .flash_attention import FlashAttention
25
+ has_flash_attn = True
26
+ except:
27
+ print('FlashAttention is not installed.')
28
+ has_flash_attn = False
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class InternRMSNorm(nn.Module):
34
+ def __init__(self, hidden_size, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+
47
+ try:
48
+ from apex.normalization import FusedRMSNorm
49
+
50
+ InternRMSNorm = FusedRMSNorm # noqa
51
+
52
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53
+ except ImportError:
54
+ # using the normal InternRMSNorm
55
+ pass
56
+ except Exception:
57
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58
+ pass
59
+
60
+
61
+ NORM2FN = {
62
+ 'rms_norm': InternRMSNorm,
63
+ 'layer_norm': nn.LayerNorm,
64
+ }
65
+
66
+
67
+ class InternVisionEmbeddings(nn.Module):
68
+ def __init__(self, config: InternVisionConfig):
69
+ super().__init__()
70
+ self.config = config
71
+ self.embed_dim = config.hidden_size
72
+ self.image_size = config.image_size
73
+ self.patch_size = config.patch_size
74
+
75
+ self.class_embedding = nn.Parameter(
76
+ torch.randn(1, 1, self.embed_dim),
77
+ )
78
+
79
+ self.patch_embedding = nn.Conv2d(
80
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
81
+ )
82
+
83
+ self.num_patches = (self.image_size // self.patch_size) ** 2
84
+ self.num_positions = self.num_patches + 1
85
+
86
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
87
+
88
+ def _get_pos_embed(self, pos_embed, H, W):
89
+ target_dtype = pos_embed.dtype
90
+ pos_embed = pos_embed.float().reshape(
91
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
92
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
93
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
94
+ return pos_embed
95
+
96
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
97
+ target_dtype = self.patch_embedding.weight.dtype
98
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height]
99
+ batch_size, _, height, width = patch_embeds.shape
100
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
101
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
102
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
103
+ position_embedding = torch.cat([
104
+ self.position_embedding[:, :1, :],
105
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
106
+ ], dim=1)
107
+ embeddings = embeddings + position_embedding.to(target_dtype)
108
+ return embeddings
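`_get_pos_embed` lets the embeddings handle inputs whose patch grid differs from the training resolution by bicubically resampling the learned position table. A small self-contained sketch of that resampling step (the grid sizes and embedding dimension below are arbitrary examples):

```python
import torch
import torch.nn.functional as F

embed_dim, old_grid, new_grid = 64, 16, 32        # e.g. trained on a 16x16 patch grid, run on 32x32
pos_embed = torch.randn(1, old_grid * old_grid, embed_dim)

pos = pos_embed.reshape(1, old_grid, old_grid, embed_dim).permute(0, 3, 1, 2)   # to NCHW for interpolate
pos = F.interpolate(pos, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
pos = pos.reshape(1, embed_dim, new_grid * new_grid).permute(0, 2, 1)           # back to (1, tokens, dim)
print(pos.shape)  # torch.Size([1, 1024, 64])
```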
109
+
110
+
111
+ class InternAttention(nn.Module):
112
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
113
+
114
+ def __init__(self, config: InternVisionConfig):
115
+ super().__init__()
116
+ self.config = config
117
+ self.embed_dim = config.hidden_size
118
+ self.num_heads = config.num_attention_heads
119
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
120
+ if config.use_flash_attn and not has_flash_attn:
121
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
122
+ self.head_dim = self.embed_dim // self.num_heads
123
+ if self.head_dim * self.num_heads != self.embed_dim:
124
+ raise ValueError(
125
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
126
+ f' {self.num_heads}).'
127
+ )
128
+
129
+ self.scale = self.head_dim ** -0.5
130
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
131
+ self.attn_drop = nn.Dropout(config.attention_dropout)
132
+ self.proj_drop = nn.Dropout(config.dropout)
133
+
134
+ self.qk_normalization = config.qk_normalization
135
+
136
+ if self.qk_normalization:
137
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
138
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
139
+
140
+ if self.use_flash_attn:
141
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
142
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
143
+
144
+ def _naive_attn(self, x):
145
+ B, N, C = x.shape
146
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
147
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
148
+
149
+ if self.qk_normalization:
150
+ B_, H_, N_, D_ = q.shape
151
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
152
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
153
+
154
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
155
+ attn = attn.softmax(dim=-1)
156
+ attn = self.attn_drop(attn)
157
+
158
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
159
+ x = self.proj(x)
160
+ x = self.proj_drop(x)
161
+ return x
162
+
163
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
164
+ qkv = self.qkv(x)
165
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
166
+
167
+ if self.qk_normalization:
168
+ q, k, v = qkv.unbind(2)
169
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
170
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
171
+ qkv = torch.stack([q, k, v], dim=2)
172
+
173
+ context, _ = self.inner_attn(
174
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
175
+ )
176
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
177
+ outs = self.proj_drop(outs)
178
+ return outs
179
+
180
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
181
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
182
+ return x
183
+
184
+
185
+ class InternMLP(nn.Module):
186
+ def __init__(self, config: InternVisionConfig):
187
+ super().__init__()
188
+ self.config = config
189
+ self.act = ACT2FN[config.hidden_act]
190
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
191
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
192
+
193
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
194
+ hidden_states = self.fc1(hidden_states)
195
+ hidden_states = self.act(hidden_states)
196
+ hidden_states = self.fc2(hidden_states)
197
+ return hidden_states
198
+
199
+
200
+ class InternVisionEncoderLayer(nn.Module):
201
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
202
+ super().__init__()
203
+ self.embed_dim = config.hidden_size
204
+ self.intermediate_size = config.intermediate_size
205
+ self.norm_type = config.norm_type
206
+
207
+ self.attn = InternAttention(config)
208
+ self.mlp = InternMLP(config)
209
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
210
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
211
+
212
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
213
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
214
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
215
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
216
+
217
+ def forward(
218
+ self,
219
+ hidden_states: torch.Tensor,
220
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
221
+ """
222
+ Args:
223
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
224
+ """
225
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
226
+
227
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
228
+
229
+ return hidden_states
230
+
231
+
232
+ class InternVisionEncoder(nn.Module):
233
+ """
234
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
235
+ [`InternEncoderLayer`].
236
+
237
+ Args:
238
+ config (`InternConfig`):
239
+ The corresponding vision configuration for the `InternEncoder`.
240
+ """
241
+
242
+ def __init__(self, config: InternVisionConfig):
243
+ super().__init__()
244
+ self.config = config
245
+ # stochastic depth decay rule
246
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
247
+ self.layers = nn.ModuleList([
248
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
249
+ self.gradient_checkpointing = True
250
+
251
+ def forward(
252
+ self,
253
+ inputs_embeds,
254
+ output_hidden_states: Optional[bool] = None,
255
+ return_dict: Optional[bool] = None,
256
+ ) -> Union[Tuple, BaseModelOutput]:
257
+ r"""
258
+ Args:
259
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
260
+ Embedded representation of the inputs. Should be float, not int tokens.
261
+ output_hidden_states (`bool`, *optional*):
262
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
263
+ for more detail.
264
+ return_dict (`bool`, *optional*):
265
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
266
+ """
267
+ output_hidden_states = (
268
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
269
+ )
270
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
271
+
272
+ encoder_states = () if output_hidden_states else None
273
+ hidden_states = inputs_embeds
274
+
275
+ for idx, encoder_layer in enumerate(self.layers):
276
+ if output_hidden_states:
277
+ encoder_states = encoder_states + (hidden_states,)
278
+ if self.gradient_checkpointing and self.training:
279
+ layer_outputs = torch.utils.checkpoint.checkpoint(
280
+ encoder_layer,
281
+ hidden_states)
282
+ else:
283
+ layer_outputs = encoder_layer(
284
+ hidden_states,
285
+ )
286
+ hidden_states = layer_outputs
287
+
288
+ if output_hidden_states:
289
+ encoder_states = encoder_states + (hidden_states,)
290
+
291
+ if not return_dict:
292
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
293
+ return BaseModelOutput(
294
+ last_hidden_state=hidden_states, hidden_states=encoder_states
295
+ )
296
+
297
+
298
+ class InternVisionModel(PreTrainedModel):
299
+ main_input_name = 'pixel_values'
300
+ _supports_flash_attn_2 = True
301
+ config_class = InternVisionConfig
302
+ _no_split_modules = ['InternVisionEncoderLayer']
303
+
304
+ def __init__(self, config: InternVisionConfig):
305
+ super().__init__(config)
306
+ self.config = config
307
+
308
+ self.embeddings = InternVisionEmbeddings(config)
309
+ self.encoder = InternVisionEncoder(config)
310
+
311
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
312
+ pos_emb = self.embeddings.position_embedding
313
+ _, num_positions, embed_dim = pos_emb.shape
314
+ cls_emb = pos_emb[:, :1, :]
315
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
316
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
317
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
318
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
319
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
320
+ self.embeddings.image_size = new_size
321
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
322
+
323
+ def get_input_embeddings(self):
324
+ return self.embeddings
325
+
326
+ def forward(
327
+ self,
328
+ pixel_values: Optional[torch.FloatTensor] = None,
329
+ output_hidden_states: Optional[bool] = None,
330
+ return_dict: Optional[bool] = None,
331
+ pixel_embeds: Optional[torch.FloatTensor] = None,
332
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
333
+ output_hidden_states = (
334
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
335
+ )
336
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
337
+
338
+ if pixel_values is None and pixel_embeds is None:
339
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
340
+
341
+ if pixel_embeds is not None:
342
+ hidden_states = pixel_embeds
343
+ else:
344
+ if len(pixel_values.shape) == 4:
345
+ hidden_states = self.embeddings(pixel_values)
346
+ else:
347
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
348
+ encoder_outputs = self.encoder(
349
+ inputs_embeds=hidden_states,
350
+ output_hidden_states=output_hidden_states,
351
+ return_dict=return_dict,
352
+ )
353
+ last_hidden_state = encoder_outputs.last_hidden_state
354
+ pooled_output = last_hidden_state[:, 0, :]
355
+
356
+ if not return_dict:
357
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
358
+
359
+ return BaseModelOutputWithPooling(
360
+ last_hidden_state=last_hidden_state,
361
+ pooler_output=pooled_output,
362
+ hidden_states=encoder_outputs.hidden_states,
363
+ attentions=encoder_outputs.attentions,
364
+ )
modeling_internlm2.py ADDED
@@ -0,0 +1,1429 @@
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch InternLM2 model."""
17
+ import math
18
+ import queue
19
+ import threading
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from einops import rearrange
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
31
+ CausalLMOutputWithPast,
32
+ SequenceClassifierOutputWithPast)
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import (add_start_docstrings,
35
+ add_start_docstrings_to_model_forward, logging,
36
+ replace_return_docstrings)
37
+
38
+ try:
39
+ from transformers.generation.streamers import BaseStreamer
40
+ except: # noqa # pylint: disable=bare-except
41
+ BaseStreamer = None
42
+
43
+ from .configuration_internlm2 import InternLM2Config
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ _CONFIG_FOR_DOC = 'InternLM2Config'
48
+
49
+ flash_attn_func, flash_attn_varlen_func = None, None
50
+ pad_input, index_first_axis, unpad_input = None, None, None
51
+ try:
52
+ from flash_attn import flash_attn_func as _flash_attn_func
53
+ from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
54
+ from flash_attn.bert_padding import index_first_axis as _index_first_axis
55
+ from flash_attn.bert_padding import pad_input as _pad_input
56
+ from flash_attn.bert_padding import unpad_input as _unpad_input
57
+
58
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
59
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
60
+ has_flash_attn = True
61
+ except:
62
+ has_flash_attn = False
63
+
64
+
65
+ def _import_flash_attn():
66
+ global flash_attn_func, flash_attn_varlen_func
67
+ global pad_input, index_first_axis, unpad_input
68
+ try:
69
+ from flash_attn import flash_attn_func as _flash_attn_func
70
+ from flash_attn import \
71
+ flash_attn_varlen_func as _flash_attn_varlen_func
72
+ from flash_attn.bert_padding import \
73
+ index_first_axis as _index_first_axis
74
+ from flash_attn.bert_padding import pad_input as _pad_input
75
+ from flash_attn.bert_padding import unpad_input as _unpad_input
76
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
77
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
78
+ except ImportError:
79
+ raise ImportError('flash_attn is not installed.')
80
+
81
+
82
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
83
+ def _get_unpad_data(attention_mask):
84
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
85
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
86
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
87
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
88
+ return (
89
+ indices,
90
+ cu_seqlens,
91
+ max_seqlen_in_batch,
92
+ )
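A quick worked example of the bookkeeping `_get_unpad_data` produces for the flash-attn varlen path, reproduced inline so it runs on its own: two sequences of lengths 3 and 1, right-padded to length 4.

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]], dtype=torch.int32)
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(indices.tolist())     # [0, 1, 2, 4] -> flattened positions of the real (non-padding) tokens
print(cu_seqlens.tolist())  # [0, 3, 4]    -> cumulative lengths handed to flash_attn_varlen_func
print(int(seqlens.max()))   # 3            -> max_seqlen_in_batch
```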
93
+
94
+
95
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
96
+ def _make_causal_mask(
97
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
98
+ ):
99
+ """
100
+ Make causal mask used for bi-directional self-attention.
101
+ """
102
+ bsz, tgt_len = input_ids_shape
103
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
104
+ mask_cond = torch.arange(mask.size(-1), device=device)
105
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
106
+ mask = mask.to(dtype)
107
+
108
+ if past_key_values_length > 0:
109
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
110
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
111
+
112
+
113
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
114
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
115
+ """
116
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
117
+ """
118
+ bsz, src_len = mask.size()
119
+ tgt_len = tgt_len if tgt_len is not None else src_len
120
+
121
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
122
+
123
+ inverted_mask = 1.0 - expanded_mask
124
+
125
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
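A toy usage of the two mask helpers, combining the causal mask with the expanded padding mask the way the decoder does; this snippet relies on the `_make_causal_mask` and `_expand_mask` functions defined above and is illustrative only.

```python
import torch

mask = torch.tensor([[1, 1, 0]])                                   # one sequence, last position padded
causal = _make_causal_mask(mask.shape, torch.float32, device=mask.device)
padding = _expand_mask(mask, torch.float32, tgt_len=3)
combined = causal + padding                                        # additive bias: 0 = attend, large negative = blocked
print((combined == 0).squeeze())                                   # True only where attention is allowed
```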
126
+
127
+
128
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
129
+ class InternLM2RMSNorm(nn.Module):
130
+ def __init__(self, hidden_size, eps=1e-6):
131
+ """
132
+ InternLM2RMSNorm is equivalent to T5LayerNorm
133
+ """
134
+ super().__init__()
135
+ self.weight = nn.Parameter(torch.ones(hidden_size))
136
+ self.variance_epsilon = eps
137
+
138
+ def forward(self, hidden_states):
139
+ input_dtype = hidden_states.dtype
140
+ hidden_states = hidden_states.to(torch.float32)
141
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
142
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
143
+ return self.weight * hidden_states.to(input_dtype)
144
+
145
+
146
+ try:
147
+ from functools import partial
148
+
149
+ from apex.normalization import FusedRMSNorm
150
+ InternLM2RMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
151
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternLM2RMSNorm')
152
+ except ImportError:
153
+ # using the normal InternLM2RMSNorm
154
+ pass
155
+ except Exception:
156
+ print('discovered apex but it failed to load, falling back to InternLM2RMSNorm')
157
+ pass
158
+
159
+
160
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
161
+ class InternLM2RotaryEmbedding(nn.Module):
162
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
163
+ super().__init__()
164
+
165
+ self.dim = dim
166
+ self.max_position_embeddings = max_position_embeddings
167
+ self.base = base
168
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
169
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
170
+
171
+ # Build here to make `torch.jit.trace` work.
172
+ self._set_cos_sin_cache(
173
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
174
+ )
175
+
176
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
177
+ self.max_seq_len_cached = seq_len
178
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
179
+
180
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
181
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
182
+ emb = torch.cat((freqs, freqs), dim=-1)
183
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
184
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
185
+
186
+ def forward(self, x, seq_len=None):
187
+ # x: [bs, num_attention_heads, seq_len, head_size]
188
+ if seq_len > self.max_seq_len_cached:
189
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
190
+
191
+ return (
192
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
193
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
194
+ )
195
+
196
+
197
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
198
+ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
199
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
200
+
201
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
202
+ self.scaling_factor = scaling_factor
203
+ super().__init__(dim, max_position_embeddings, base, device)
204
+
205
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
206
+ self.max_seq_len_cached = seq_len
207
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
208
+ t = t / self.scaling_factor
209
+
210
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
211
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
212
+ emb = torch.cat((freqs, freqs), dim=-1)
213
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
214
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
215
+
216
+
217
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
218
+ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
219
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
220
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
221
+ """
222
+
223
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
224
+ self.scaling_factor = scaling_factor
225
+ super().__init__(dim, max_position_embeddings, base, device)
226
+
227
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
228
+ self.max_seq_len_cached = seq_len
229
+
230
+ if seq_len > self.max_position_embeddings:
231
+ base = self.base * (
232
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
233
+ ) ** (self.dim / (self.dim - 2))
234
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
235
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
236
+
237
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
238
+
239
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
240
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
241
+ emb = torch.cat((freqs, freqs), dim=-1)
242
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
243
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
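For intuition, the rescaling applied when `seq_len` exceeds `max_position_embeddings` can be worked out by hand; with the default `scaling_factor=1.0`, a 2x longer context roughly doubles the rotary base. The numbers below are illustrative only:

```python
# head_dim=128, base=10000, trained context 2048, prompt of 4096 tokens
dim, base, max_pos, factor, seq_len = 128, 10000.0, 2048, 1.0, 4096
new_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
print(round(new_base))  # ~20221 -- lower frequencies stretch so positions beyond 2048 stay distinguishable
```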
244
+
245
+
246
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
247
+ def rotate_half(x):
248
+ """Rotates half the hidden dims of the input."""
249
+ x1 = x[..., : x.shape[-1] // 2]
250
+ x2 = x[..., x.shape[-1] // 2:]
251
+ return torch.cat((-x2, x1), dim=-1)
252
+
253
+
254
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
255
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
256
+ """Applies Rotary Position Embedding to the query and key tensors."""
257
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
258
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
259
+ q_embed = (q * cos) + (rotate_half(q) * sin)
260
+ k_embed = (k * cos) + (rotate_half(k) * sin)
261
+ return q_embed, k_embed
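A tiny end-to-end check of the rotary helpers (with a standalone copy of `rotate_half` so the snippet runs on its own): build the duplicated cos/sin layout for a 4-dimensional head and rotate a random query in the `(bsz, n_heads, seq_len, head_dim)` layout used by the attention layers.

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

dim, seq_len = 4, 3
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum('i,j->ij', torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                 # same duplicated layout as the cached cos/sin
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 1, seq_len, dim)                     # (bsz, n_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len)[None]              # (bsz, seq_len)
cos_sel = cos[position_ids].unsqueeze(1)                # unsqueeze_dim=1 broadcasts over heads
sin_sel = sin[position_ids].unsqueeze(1)
q_rot = (q * cos_sel) + (rotate_half(q) * sin_sel)
print(q_rot.shape)                                      # torch.Size([1, 1, 3, 4])
```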
262
+
263
+
264
+ class InternLM2MLP(nn.Module):
265
+ def __init__(self, config):
266
+ super().__init__()
267
+ self.config = config
268
+ self.hidden_size = config.hidden_size
269
+ self.intermediate_size = config.intermediate_size
270
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
271
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
272
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
273
+ self.act_fn = ACT2FN[config.hidden_act]
274
+
275
+ def forward(self, x):
276
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
277
+
278
+ return down_proj
279
+
280
+
281
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
282
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
283
+ """
284
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
285
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
286
+ """
287
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
288
+ if n_rep == 1:
289
+ return hidden_states
290
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
291
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
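The docstring's claim that this matches `torch.repeat_interleave` along the key/value head dimension is easy to verify with a quick check:

```python
import torch

kv = torch.randn(2, 4, 5, 8)      # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3
expanded = kv[:, :, None, :, :].expand(2, 4, n_rep, 5, 8).reshape(2, 4 * n_rep, 5, 8)
reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)
print(torch.equal(expanded, reference))  # True
```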
292
+
293
+
294
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
295
+ class InternLM2Attention(nn.Module):
296
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
297
+
298
+ def __init__(self, config: InternLM2Config):
299
+ super().__init__()
300
+ self.config = config
301
+ self.hidden_size = config.hidden_size
302
+ self.num_heads = config.num_attention_heads
303
+ self.head_dim = self.hidden_size // self.num_heads
304
+ self.num_key_value_heads = config.num_key_value_heads
305
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
306
+ self.max_position_embeddings = config.max_position_embeddings
307
+ self.is_causal = True
308
+
309
+ if (self.head_dim * self.num_heads) != self.hidden_size:
310
+ raise ValueError(
311
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
312
+ f' and `num_heads`: {self.num_heads}).'
313
+ )
314
+
315
+ self.wqkv = nn.Linear(
316
+ self.hidden_size,
317
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
318
+ bias=config.bias,
319
+ )
320
+
321
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
322
+ self._init_rope()
323
+
324
+ def _init_rope(self):
325
+ if self.config.rope_scaling is None:
326
+ self.rotary_emb = InternLM2RotaryEmbedding(
327
+ self.head_dim,
328
+ max_position_embeddings=self.max_position_embeddings,
329
+ base=self.config.rope_theta,
330
+ )
331
+ else:
332
+ scaling_type = self.config.rope_scaling['type']
333
+ scaling_factor = self.config.rope_scaling['factor']
334
+ if scaling_type == 'dynamic':
335
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
336
+ self.head_dim,
337
+ max_position_embeddings=self.max_position_embeddings,
338
+ base=self.config.rope_theta,
339
+ scaling_factor=scaling_factor,
340
+ )
341
+ elif scaling_type == 'linear':
342
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
343
+ self.head_dim,
344
+ max_position_embeddings=self.max_position_embeddings,
345
+ base=self.config.rope_theta,
346
+ scaling_factor=scaling_factor,
347
+ )
348
+ else:
349
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
350
+ return self.rotary_emb
351
+
352
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
353
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
354
+
355
+ def forward(
356
+ self,
357
+ hidden_states: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ position_ids: Optional[torch.LongTensor] = None,
360
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
361
+ output_attentions: bool = False,
362
+ use_cache: bool = False,
363
+ **kwargs,
364
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
365
+ if 'padding_mask' in kwargs:
366
+ warnings.warn(
367
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
368
+ 'Please make sure to use `attention_mask` instead.'
369
+ )
370
+
371
+ bsz, q_len, _ = hidden_states.size()
372
+
373
+ qkv_states = self.wqkv(hidden_states)
374
+
375
+ qkv_states = rearrange(
376
+ qkv_states,
377
+ 'b q (h gs d) -> b q h gs d',
378
+ gs=2 + self.num_key_value_groups,
379
+ d=self.head_dim,
380
+ )
381
+
382
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
383
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
384
+ key_states = qkv_states[..., -2, :]
385
+ value_states = qkv_states[..., -1, :]
386
+
387
+ query_states = query_states.transpose(1, 2)
388
+ key_states = key_states.transpose(1, 2)
389
+ value_states = value_states.transpose(1, 2)
390
+
391
+ kv_seq_len = key_states.shape[-2]
392
+ if past_key_value is not None:
393
+ kv_seq_len += past_key_value[0].shape[-2]
394
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
395
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
396
+
397
+ if past_key_value is not None:
398
+ # reuse k, v, self_attention
399
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
400
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
401
+
402
+ past_key_value = (key_states, value_states) if use_cache else None
403
+
404
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
405
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
406
+
407
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
408
+
409
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
410
+ raise ValueError(
411
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
412
+ f' {attn_weights.size()}'
413
+ )
414
+
415
+ if attention_mask is not None:
416
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
417
+ raise ValueError(
418
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
419
+ )
420
+ attn_weights = attn_weights + attention_mask
421
+
422
+ # upcast attention to fp32
423
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
424
+ attn_output = torch.matmul(attn_weights, value_states)
425
+
426
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
427
+ raise ValueError(
428
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
429
+ f' {attn_output.size()}'
430
+ )
431
+
432
+ attn_output = attn_output.transpose(1, 2).contiguous()
433
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
434
+
435
+ attn_output = self.wo(attn_output)
436
+
437
+ if not output_attentions:
438
+ attn_weights = None
439
+
440
+ return attn_output, attn_weights, past_key_value
441
+
442
+
443
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
444
+ class InternLM2FlashAttention2(InternLM2Attention):
445
+ """
446
+ InternLM2 flash attention module. This module inherits from `InternLM2Attention`, as the weights of the module stay
447
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
448
+ flash attention and deal with padding tokens in case the input contains any of them.
449
+ """
450
+
451
+ def forward(
452
+ self,
453
+ hidden_states: torch.Tensor,
454
+ attention_mask: Optional[torch.LongTensor] = None,
455
+ position_ids: Optional[torch.LongTensor] = None,
456
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
457
+ output_attentions: bool = False,
458
+ use_cache: bool = False,
459
+ **kwargs,
460
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
461
+ # InternLM2FlashAttention2 attention does not support output_attentions
462
+ if 'padding_mask' in kwargs:
463
+ warnings.warn(
464
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
465
+ 'Please make sure to use `attention_mask` instead.'
466
+ )
467
+
468
+ # overwrite attention_mask with padding_mask
469
+ attention_mask = kwargs.pop('padding_mask')
470
+
471
+ output_attentions = False
472
+
473
+ bsz, q_len, _ = hidden_states.size()
474
+
475
+ qkv_states = self.wqkv(hidden_states)
476
+
477
+ qkv_states = rearrange(
478
+ qkv_states,
479
+ 'b q (h gs d) -> b q h gs d',
480
+ gs=2 + self.num_key_value_groups,
481
+ d=self.head_dim,
482
+ )
483
+
484
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
485
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
486
+ key_states = qkv_states[..., -2, :]
487
+ value_states = qkv_states[..., -1, :]
488
+
489
+ query_states = query_states.transpose(1, 2)
490
+ key_states = key_states.transpose(1, 2)
491
+ value_states = value_states.transpose(1, 2)
492
+
493
+ kv_seq_len = key_states.shape[-2]
494
+ if past_key_value is not None:
495
+ kv_seq_len += past_key_value[0].shape[-2]
496
+
497
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
498
+
499
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
500
+
501
+ if past_key_value is not None:
502
+ # reuse k, v, self_attention
503
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
504
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
505
+
506
+ past_key_value = (key_states, value_states) if use_cache else None
507
+
508
+ query_states = query_states.transpose(1, 2)
509
+ key_states = key_states.transpose(1, 2)
510
+ value_states = value_states.transpose(1, 2)
511
+
512
+ attn_output = self._flash_attention_forward(
513
+ query_states, key_states, value_states, attention_mask, q_len
514
+ )
515
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
516
+ attn_output = self.wo(attn_output)
517
+
518
+ if not output_attentions:
519
+ attn_weights = None
520
+
521
+ return attn_output, attn_weights, past_key_value
522
+
523
+ def _flash_attention_forward(
524
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
525
+ ):
526
+ """
527
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
528
+ first unpad the input, then compute the attention scores, and finally pad the attention scores back.
529
+
530
+ Args:
531
+ query_states (`torch.Tensor`):
532
+ Input query states to be passed to Flash Attention API
533
+ key_states (`torch.Tensor`):
534
+ Input key states to be passed to Flash Attention API
535
+ value_states (`torch.Tensor`):
536
+ Input value states to be passed to Flash Attention API
537
+ attention_mask (`torch.Tensor`):
538
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
539
+ position of padding tokens and 1 for the position of non-padding tokens.
540
+ dropout (`float`, *optional*):
541
+ Attention dropout
542
+ softmax_scale (`float`, *optional*):
543
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
544
+ """
545
+ # Contains at least one padding token in the sequence
546
+ causal = self.is_causal and query_length != 1
547
+ if attention_mask is not None:
548
+ batch_size = query_states.shape[0]
549
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
550
+ query_states, key_states, value_states, attention_mask, query_length
551
+ )
552
+
553
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
554
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
555
+
556
+ attn_output_unpad = flash_attn_varlen_func(
557
+ query_states,
558
+ key_states,
559
+ value_states,
560
+ cu_seqlens_q=cu_seqlens_q,
561
+ cu_seqlens_k=cu_seqlens_k,
562
+ max_seqlen_q=max_seqlen_in_batch_q,
563
+ max_seqlen_k=max_seqlen_in_batch_k,
564
+ dropout_p=dropout,
565
+ softmax_scale=softmax_scale,
566
+ causal=causal,
567
+ )
568
+
569
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
570
+ else:
571
+ attn_output = flash_attn_func(
572
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
573
+ )
574
+
575
+ return attn_output
576
+
577
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
578
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
579
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
580
+
581
+ key_layer = index_first_axis(
582
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
583
+ )
584
+ value_layer = index_first_axis(
585
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
586
+ )
587
+
588
+ if query_length == kv_seq_len:
589
+ query_layer = index_first_axis(
590
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
591
+ )
592
+ cu_seqlens_q = cu_seqlens_k
593
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
594
+ indices_q = indices_k
595
+ elif query_length == 1:
596
+ max_seqlen_in_batch_q = 1
597
+ cu_seqlens_q = torch.arange(
598
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
599
+ ) # There is a memcpy here, that is very bad.
600
+ indices_q = cu_seqlens_q[:-1]
601
+ query_layer = query_layer.squeeze(1)
602
+ else:
603
+ # The -q_len: slice assumes left padding.
604
+ attention_mask = attention_mask[:, -query_length:]
605
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
606
+
607
+ return (
608
+ query_layer,
609
+ key_layer,
610
+ value_layer,
611
+ indices_q.to(torch.int64),
612
+ (cu_seqlens_q, cu_seqlens_k),
613
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
614
+ )
615
+
616
+
617
+ INTERNLM2_ATTENTION_CLASSES = {
618
+ 'eager': InternLM2Attention,
619
+ 'flash_attention_2': InternLM2FlashAttention2,
620
+ }
621
+
622
+
623
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
624
+ class InternLM2DecoderLayer(nn.Module):
625
+ def __init__(self, config: InternLM2Config):
626
+ super().__init__()
627
+ self.hidden_size = config.hidden_size
628
+
629
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
630
+
631
+ self.feed_forward = InternLM2MLP(config)
632
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
633
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states: torch.Tensor,
638
+ attention_mask: Optional[torch.Tensor] = None,
639
+ position_ids: Optional[torch.LongTensor] = None,
640
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
641
+ output_attentions: Optional[bool] = False,
642
+ use_cache: Optional[bool] = False,
643
+ **kwargs,
644
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
645
+ """
646
+ Args:
647
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
648
+ attention_mask (`torch.FloatTensor`, *optional*):
649
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
650
+ query_sequence_length, key_sequence_length)` if default attention is used.
651
+ output_attentions (`bool`, *optional*):
652
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
653
+ returned tensors for more detail.
654
+ use_cache (`bool`, *optional*):
655
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
656
+ (see `past_key_values`).
657
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
658
+ """
659
+ if 'padding_mask' in kwargs:
660
+ warnings.warn(
661
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
662
+ 'Please make sure to use `attention_mask` instead.'
663
+ )
664
+
665
+ residual = hidden_states
666
+
667
+ hidden_states = self.attention_norm(hidden_states)
668
+
669
+ # Self Attention
670
+ hidden_states, self_attn_weights, present_key_value = self.attention(
671
+ hidden_states=hidden_states,
672
+ attention_mask=attention_mask,
673
+ position_ids=position_ids,
674
+ past_key_value=past_key_value,
675
+ output_attentions=output_attentions,
676
+ use_cache=use_cache,
677
+ **kwargs,
678
+ )
679
+ hidden_states = residual + hidden_states
680
+
681
+ # Fully Connected
682
+ residual = hidden_states
683
+ hidden_states = self.ffn_norm(hidden_states)
684
+ hidden_states = self.feed_forward(hidden_states)
685
+ hidden_states = residual + hidden_states
686
+
687
+ outputs = (hidden_states,)
688
+
689
+ if output_attentions:
690
+ outputs += (self_attn_weights,)
691
+
692
+ if use_cache:
693
+ outputs += (present_key_value,)
694
+
695
+ return outputs
696
+
697
+
698
+ InternLM2_START_DOCSTRING = r"""
699
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
700
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
701
+ etc.)
702
+
703
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
704
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
705
+ and behavior.
706
+
707
+ Parameters:
708
+ config ([`InternLM2Config`]):
709
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
710
+ load the weights associated with the model, only the configuration. Check out the
711
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
712
+ """
713
+
714
+
715
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
716
+ @add_start_docstrings(
717
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
718
+ InternLM2_START_DOCSTRING,
719
+ )
720
+ class InternLM2PreTrainedModel(PreTrainedModel):
721
+ config_class = InternLM2Config
722
+ base_model_prefix = 'model'
723
+ supports_gradient_checkpointing = True
724
+ _no_split_modules = ['InternLM2DecoderLayer']
725
+ _skip_keys_device_placement = 'past_key_values'
726
+ _supports_flash_attn_2 = True
727
+
728
+ def _init_weights(self, module):
729
+ std = self.config.initializer_range
730
+ if isinstance(module, nn.Linear):
731
+ module.weight.data.normal_(mean=0.0, std=std)
732
+ if module.bias is not None:
733
+ module.bias.data.zero_()
734
+ elif isinstance(module, nn.Embedding):
735
+ module.weight.data.normal_(mean=0.0, std=std)
736
+ if module.padding_idx is not None:
737
+ module.weight.data[module.padding_idx].zero_()
738
+
739
+
740
+ InternLM2_INPUTS_DOCSTRING = r"""
741
+ Args:
742
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
743
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
744
+ it.
745
+
746
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
747
+ [`PreTrainedTokenizer.__call__`] for details.
748
+
749
+ [What are input IDs?](../glossary#input-ids)
750
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
751
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
752
+
753
+ - 1 for tokens that are **not masked**,
754
+ - 0 for tokens that are **masked**.
755
+
756
+ [What are attention masks?](../glossary#attention-mask)
757
+
758
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
759
+ [`PreTrainedTokenizer.__call__`] for details.
760
+
761
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
762
+ `past_key_values`).
763
+
764
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
765
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
766
+ information on the default strategy.
767
+
768
+ - 1 indicates the head is **not masked**,
769
+ - 0 indicates the head is **masked**.
770
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
771
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
772
+ config.n_positions - 1]`.
773
+
774
+ [What are position IDs?](../glossary#position-ids)
775
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
776
+ when `config.use_cache=True`):
777
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
778
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
779
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
780
+
781
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
782
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
783
+
784
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
785
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
786
+ of shape `(batch_size, sequence_length)`.
787
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
788
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
789
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
790
+ model's internal embedding lookup matrix.
791
+ use_cache (`bool`, *optional*):
792
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
793
+ `past_key_values`).
794
+ output_attentions (`bool`, *optional*):
795
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
796
+ tensors for more detail.
797
+ output_hidden_states (`bool`, *optional*):
798
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
799
+ more detail.
800
+ return_dict (`bool`, *optional*):
801
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
802
+ """
803
+
804
+
805
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
806
+ @add_start_docstrings(
807
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
808
+ InternLM2_START_DOCSTRING,
809
+ )
810
+ class InternLM2Model(InternLM2PreTrainedModel):
811
+ """
812
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
813
+
814
+ Args:
815
+ config: InternLM2Config
816
+ """
817
+
818
+ _auto_class = 'AutoModel'
819
+
820
+ def __init__(self, config: InternLM2Config):
821
+ super().__init__(config)
822
+ self.padding_idx = config.pad_token_id
823
+ self.vocab_size = config.vocab_size
824
+ self.config = config
825
+ if not has_flash_attn:
826
+ self.config.attn_implementation = 'eager'
827
+ print('Warning: Flash attention is not available, using eager attention instead.')
828
+
829
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
830
+
831
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
832
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
833
+
834
+ self.gradient_checkpointing = False
835
+ # Initialize weights and apply final processing
836
+ self.post_init()
837
+
838
+ def get_input_embeddings(self):
839
+ return self.tok_embeddings
840
+
841
+ def set_input_embeddings(self, value):
842
+ self.tok_embeddings = value
843
+
844
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
845
+ # create causal mask
846
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
847
+ combined_attention_mask = None
848
+ if input_shape[-1] > 1:
849
+ combined_attention_mask = _make_causal_mask(
850
+ input_shape,
851
+ inputs_embeds.dtype,
852
+ device=inputs_embeds.device,
853
+ past_key_values_length=past_key_values_length,
854
+ )
855
+
856
+ if attention_mask is not None:
857
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
858
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
859
+ inputs_embeds.device
860
+ )
861
+ combined_attention_mask = (
862
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
863
+ )
864
+
865
+ return combined_attention_mask
866
+
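A minimal sketch of what `_prepare_decoder_attention_mask` produces: an additive `[bsz, 1, tgt_len, src_len]` mask in which future and padded positions receive a large negative value. The helper below is illustrative only; the real `_make_causal_mask`/`_expand_mask` also handle `past_key_values_length` and dtype details.

```python
import torch

def toy_combined_mask(attention_mask, dtype=torch.float32):
    """Illustrative only: combine a causal mask and a padding mask additively."""
    bsz, seq_len = attention_mask.shape
    min_val = torch.finfo(dtype).min
    causal = torch.triu(torch.full((seq_len, seq_len), min_val, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :]                                        # [1, 1, tgt, src]
    padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_val   # [bsz, 1, 1, src]
    return causal + padding

print(toy_combined_mask(torch.tensor([[1, 1, 1, 0]])).shape)  # torch.Size([1, 1, 4, 4])
```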
867
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
868
+ def forward(
869
+ self,
870
+ input_ids: torch.LongTensor = None,
871
+ attention_mask: Optional[torch.Tensor] = None,
872
+ position_ids: Optional[torch.LongTensor] = None,
873
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
874
+ inputs_embeds: Optional[torch.FloatTensor] = None,
875
+ use_cache: Optional[bool] = None,
876
+ output_attentions: Optional[bool] = None,
877
+ output_hidden_states: Optional[bool] = None,
878
+ return_dict: Optional[bool] = None,
879
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
880
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
881
+ output_hidden_states = (
882
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
883
+ )
884
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
885
+
886
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
887
+
888
+ if self.config.attn_implementation == 'flash_attention_2':
889
+ _import_flash_attn()
890
+
891
+ # retrieve input_ids and inputs_embeds
892
+ if input_ids is not None and inputs_embeds is not None:
893
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
894
+ elif input_ids is not None:
895
+ batch_size, seq_length = input_ids.shape[:2]
896
+ elif inputs_embeds is not None:
897
+ batch_size, seq_length = inputs_embeds.shape[:2]
898
+ else:
899
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
900
+
901
+ seq_length_with_past = seq_length
902
+ past_key_values_length = 0
903
+ if past_key_values is not None:
904
+ past_key_values_length = past_key_values[0][0].shape[2]
905
+ seq_length_with_past = seq_length_with_past + past_key_values_length
906
+
907
+ if position_ids is None:
908
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
909
+ position_ids = torch.arange(
910
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
911
+ )
912
+ position_ids = position_ids.unsqueeze(0)
913
+
914
+ if inputs_embeds is None:
915
+ inputs_embeds = self.tok_embeddings(input_ids)
916
+
917
+ if self.config.attn_implementation == 'flash_attention_2':
918
+ # 2d mask is passed through the layers
919
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
920
+ else:
921
+ if attention_mask is None:
922
+ attention_mask = torch.ones(
923
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
924
+ )
925
+ attention_mask = self._prepare_decoder_attention_mask(
926
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
927
+ )
928
+
929
+ # embed positions
930
+ hidden_states = inputs_embeds
931
+
932
+ if self.gradient_checkpointing and self.training:
933
+ if use_cache:
934
+ logger.warning_once(
935
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
936
+ )
937
+ use_cache = False
938
+
939
+ # decoder layers
940
+ all_hidden_states = () if output_hidden_states else None
941
+ all_self_attns = () if output_attentions else None
942
+ next_decoder_cache = () if use_cache else None
943
+
944
+ for idx, decoder_layer in enumerate(self.layers):
945
+ if output_hidden_states:
946
+ all_hidden_states += (hidden_states,)
947
+
948
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
949
+
950
+ if self.gradient_checkpointing and self.training:
951
+
952
+ def create_custom_forward(module):
953
+ def custom_forward(*inputs):
954
+ # None for past_key_value
955
+ return module(*inputs, output_attentions, None)
956
+
957
+ return custom_forward
958
+
959
+ layer_outputs = torch.utils.checkpoint.checkpoint(
960
+ create_custom_forward(decoder_layer),
961
+ hidden_states,
962
+ attention_mask,
963
+ position_ids,
964
+ None,
965
+ )
966
+ else:
967
+ layer_outputs = decoder_layer(
968
+ hidden_states,
969
+ attention_mask=attention_mask,
970
+ position_ids=position_ids,
971
+ past_key_value=past_key_value,
972
+ output_attentions=output_attentions,
973
+ use_cache=use_cache,
974
+ )
975
+
976
+ hidden_states = layer_outputs[0]
977
+
978
+ if use_cache:
979
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
980
+
981
+ if output_attentions:
982
+ all_self_attns += (layer_outputs[1],)
983
+
984
+ hidden_states = self.norm(hidden_states)
985
+
986
+ # add hidden states from the last decoder layer
987
+ if output_hidden_states:
988
+ all_hidden_states += (hidden_states,)
989
+
990
+ next_cache = next_decoder_cache if use_cache else None
991
+ if not return_dict:
992
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
993
+ return BaseModelOutputWithPast(
994
+ last_hidden_state=hidden_states,
995
+ past_key_values=next_cache,
996
+ hidden_states=all_hidden_states,
997
+ attentions=all_self_attns,
998
+ )
999
+
1000
+
1001
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
1002
+ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1003
+ _auto_class = 'AutoModelForCausalLM'
1004
+
1005
+ _tied_weights_keys = ['output.weight']
1006
+
1007
+ def __init__(self, config):
1008
+ super().__init__(config)
1009
+ self.model = InternLM2Model(config)
1010
+ self.vocab_size = config.vocab_size
1011
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1012
+
1013
+ # Initialize weights and apply final processing
1014
+ self.post_init()
1015
+
1016
+ def get_input_embeddings(self):
1017
+ return self.model.tok_embeddings
1018
+
1019
+ def set_input_embeddings(self, value):
1020
+ self.model.tok_embeddings = value
1021
+
1022
+ def get_output_embeddings(self):
1023
+ return self.output
1024
+
1025
+ def set_output_embeddings(self, new_embeddings):
1026
+ self.output = new_embeddings
1027
+
1028
+ def set_decoder(self, decoder):
1029
+ self.model = decoder
1030
+
1031
+ def get_decoder(self):
1032
+ return self.model
1033
+
1034
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1035
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1036
+ def forward(
1037
+ self,
1038
+ input_ids: torch.LongTensor = None,
1039
+ attention_mask: Optional[torch.Tensor] = None,
1040
+ position_ids: Optional[torch.LongTensor] = None,
1041
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1042
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1043
+ labels: Optional[torch.LongTensor] = None,
1044
+ use_cache: Optional[bool] = None,
1045
+ output_attentions: Optional[bool] = None,
1046
+ output_hidden_states: Optional[bool] = None,
1047
+ return_dict: Optional[bool] = None,
1048
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1049
+ r"""
1050
+ Args:
1051
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1052
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1053
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1054
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1055
+
1056
+ Returns:
1057
+
1058
+ Example:
1059
+
1060
+ ```python
1061
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
1062
+
1063
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1064
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1065
+
1066
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1067
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1068
+
1069
+ >>> # Generate
1070
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1071
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1072
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1073
+ ```"""
1074
+
1075
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1076
+ output_hidden_states = (
1077
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1078
+ )
1079
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1080
+
1081
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1082
+ outputs = self.model(
1083
+ input_ids=input_ids,
1084
+ attention_mask=attention_mask,
1085
+ position_ids=position_ids,
1086
+ past_key_values=past_key_values,
1087
+ inputs_embeds=inputs_embeds,
1088
+ use_cache=use_cache,
1089
+ output_attentions=output_attentions,
1090
+ output_hidden_states=output_hidden_states,
1091
+ return_dict=return_dict,
1092
+ )
1093
+
1094
+ hidden_states = outputs[0]
1095
+ logits = self.output(hidden_states)
1096
+ logits = logits.float()
1097
+
1098
+ loss = None
1099
+ if labels is not None:
1100
+ # Shift so that tokens < n predict n
1101
+ shift_logits = logits[..., :-1, :].contiguous()
1102
+ shift_labels = labels[..., 1:].contiguous()
1103
+ # Flatten the tokens
1104
+ loss_fct = CrossEntropyLoss()
1105
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1106
+ shift_labels = shift_labels.view(-1)
1107
+ # Enable model parallelism
1108
+ shift_labels = shift_labels.to(shift_logits.device)
1109
+ loss = loss_fct(shift_logits, shift_labels)
1110
+
1111
+ if not return_dict:
1112
+ output = (logits,) + outputs[1:]
1113
+ return (loss,) + output if loss is not None else output
1114
+
1115
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1116
+ output = CausalLMOutputWithPast(
1117
+ loss=loss,
1118
+ logits=logits,
1119
+ past_key_values=outputs.past_key_values,
1120
+ hidden_states=outputs.hidden_states,
1121
+ attentions=outputs.attentions,
1122
+ )
1123
+ output['logits'] = output['logits'].to(device)
1124
+ return output
1125
+
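The label shift in the loss above pairs the logits at position t with the token at position t+1. A tiny self-contained sketch of that computation (random tensors, purely illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)          # [batch, seq_len, vocab]
labels = torch.randint(0, vocab_size, (2, 5))   # positions set to -100 would be ignored

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
print(CrossEntropyLoss()(shift_logits, shift_labels))
```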
1126
+ def prepare_inputs_for_generation(
1127
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1128
+ ):
1129
+ if past_key_values is not None:
1130
+ past_length = past_key_values[0][0].shape[2]
1131
+
1132
+ # Some generation methods already pass only the last input ID
1133
+ if input_ids.shape[1] > past_length:
1134
+ remove_prefix_length = past_length
1135
+ else:
1136
+ # Default to old behavior: keep only final ID
1137
+ remove_prefix_length = input_ids.shape[1] - 1
1138
+
1139
+ input_ids = input_ids[:, remove_prefix_length:]
1140
+
1141
+ position_ids = kwargs.get('position_ids', None)
1142
+ if attention_mask is not None and position_ids is None:
1143
+ # create position_ids on the fly for batch generation
1144
+ position_ids = attention_mask.long().cumsum(-1) - 1
1145
+ position_ids.masked_fill_(attention_mask == 0, 1)
1146
+ if past_key_values:
1147
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1148
+
1149
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1150
+ if inputs_embeds is not None and past_key_values is None:
1151
+ model_inputs = {'inputs_embeds': inputs_embeds}
1152
+ else:
1153
+ model_inputs = {'input_ids': input_ids}
1154
+
1155
+ model_inputs.update(
1156
+ {
1157
+ 'position_ids': position_ids,
1158
+ 'past_key_values': past_key_values,
1159
+ 'use_cache': kwargs.get('use_cache'),
1160
+ 'attention_mask': attention_mask,
1161
+ }
1162
+ )
1163
+ return model_inputs
1164
+
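A worked example of the on-the-fly `position_ids` construction above, for a left-padded batch (illustrative values only):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```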
1165
+ @staticmethod
1166
+ def _reorder_cache(past_key_values, beam_idx):
1167
+ reordered_past = ()
1168
+ for layer_past in past_key_values:
1169
+ reordered_past += (
1170
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1171
+ )
1172
+ return reordered_past
1173
+
1174
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=''):
1175
+ if tokenizer.add_bos_token:
1176
+ prompt = ''
1177
+ else:
1178
+ prompt = tokenizer.bos_token
1179
+ if meta_instruction:
1180
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1181
+ for record in history:
1182
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1183
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
1184
+ return tokenizer([prompt], return_tensors='pt')
1185
+
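For a single-turn query, the string assembled by `build_inputs` follows a ChatML-style layout roughly like the one below (a BOS token may be prepended, depending on `tokenizer.add_bos_token`):

```
<|im_start|>system
{meta_instruction}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
```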
1186
+ @torch.no_grad()
1187
+ def chat(
1188
+ self,
1189
+ tokenizer,
1190
+ query: str,
1191
+ history: List[Tuple[str, str]] = [],
1192
+ streamer: Optional[BaseStreamer] = None,
1193
+ max_new_tokens: int = 1024,
1194
+ do_sample: bool = True,
1195
+ temperature: float = 0.8,
1196
+ top_p: float = 0.8,
1197
+ meta_instruction: str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
1198
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
1199
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.',
1200
+ **kwargs,
1201
+ ):
1202
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1203
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1204
+ # also add the end-of-assistant token to the eos token ids to avoid unnecessary generation
1205
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]]
1206
+ outputs = self.generate(
1207
+ **inputs,
1208
+ streamer=streamer,
1209
+ max_new_tokens=max_new_tokens,
1210
+ do_sample=do_sample,
1211
+ temperature=temperature,
1212
+ top_p=top_p,
1213
+ eos_token_id=eos_token_id,
1214
+ **kwargs,
1215
+ )
1216
+ outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
1217
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
1218
+ response = response.split('<|im_end|>')[0]
1219
+ history = history + [(query, response)]
1220
+ return response, history
1221
+
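A hedged usage sketch for `chat()`; the checkpoint path is a placeholder, `trust_remote_code=True` is assumed so this modeling file is loaded, and a CUDA device is assumed to be available.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = 'internlm/internlm2-chat-7b'  # placeholder checkpoint that ships this remote code
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path, torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda().eval()  # assumes a CUDA device is available

response, history = model.chat(tokenizer, 'Hello! Who are you?')
print(response)
response, history = model.chat(tokenizer, 'Can you say that more briefly?', history=history)
```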
1222
+ @torch.no_grad()
1223
+ def stream_chat(
1224
+ self,
1225
+ tokenizer,
1226
+ query: str,
1227
+ history: List[Tuple[str, str]] = [],
1228
+ max_new_tokens: int = 1024,
1229
+ do_sample: bool = True,
1230
+ temperature: float = 0.8,
1231
+ top_p: float = 0.8,
1232
+ **kwargs,
1233
+ ):
1234
+ """
1235
+ Return a generator that yields tuples in the format (response, history).
1236
+ E.g.
1237
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
1238
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
1239
+ """
1240
+ if BaseStreamer is None:
1241
+ raise ModuleNotFoundError(
1242
+ 'The version of `transformers` is too low. Please make sure '
1243
+ 'that you have installed `transformers>=4.28.0`.'
1244
+ )
1245
+
1246
+ response_queue = queue.Queue(maxsize=20)
1247
+
1248
+ class ChatStreamer(BaseStreamer):
1249
+ def __init__(self, tokenizer) -> None:
1250
+ super().__init__()
1251
+ self.tokenizer = tokenizer
1252
+ self.queue = response_queue
1253
+ self.query = query
1254
+ self.history = history
1255
+ self.response = ''
1256
+ self.cache = []
1257
+ self.received_inputs = False
1258
+ self.queue.put((self.response, history + [(self.query, self.response)]))
1259
+
1260
+ def put(self, value):
1261
+ if len(value.shape) > 1 and value.shape[0] > 1:
1262
+ raise ValueError('ChatStreamer only supports batch size 1')
1263
+ elif len(value.shape) > 1:
1264
+ value = value[0]
1265
+
1266
+ if not self.received_inputs:
1267
+ # The first received value is input_ids, ignore here
1268
+ self.received_inputs = True
1269
+ return
1270
+
1271
+ self.cache.extend(value.tolist())
1272
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
1273
+ if token.strip() != '<|im_end|>':
1274
+ self.response = self.response + token
1275
+ history = self.history + [(self.query, self.response)]
1276
+ self.queue.put((self.response, history))
1277
+ self.cache = []
1278
+ else:
1279
+ self.end()
1280
+
1281
+ def end(self):
1282
+ self.queue.put(None)
1283
+
1284
+ def stream_producer():
1285
+ return self.chat(
1286
+ tokenizer=tokenizer,
1287
+ query=query,
1288
+ streamer=ChatStreamer(tokenizer=tokenizer),
1289
+ history=history,
1290
+ max_new_tokens=max_new_tokens,
1291
+ do_sample=do_sample,
1292
+ temperature=temperature,
1293
+ top_p=top_p,
1294
+ **kwargs,
1295
+ )
1296
+
1297
+ def consumer():
1298
+ producer = threading.Thread(target=stream_producer)
1299
+ producer.start()
1300
+ while True:
1301
+ res = response_queue.get()
1302
+ if res is None:
1303
+ return
1304
+ yield res
1305
+
1306
+ return consumer()
1307
+
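A short usage sketch for `stream_chat()`: it returns a generator of cumulative `(response, history)` tuples, assuming `model` and `tokenizer` were loaded as in the `chat()` example above.

```python
for response, history in model.stream_chat(tokenizer, 'Hello!'):
    print(response, end='\r')   # each yield contains the response generated so far
print()
```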
1308
+
1309
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1310
+ @add_start_docstrings(
1311
+ """
1312
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
1313
+
1314
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
1315
+ as other causal models (e.g. GPT-2) do.
1316
+
1317
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1318
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1319
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1320
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1321
+ each row of the batch).
1322
+ """,
1323
+ InternLM2_START_DOCSTRING,
1324
+ )
1325
+ class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
1326
+ def __init__(self, config):
1327
+ super().__init__(config)
1328
+ self.num_labels = config.num_labels
1329
+ self.model = InternLM2Model(config)
1330
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1331
+
1332
+ # Initialize weights and apply final processing
1333
+ self.post_init()
1334
+
1335
+ def get_input_embeddings(self):
1336
+ return self.model.tok_embeddings
1337
+
1338
+ def set_input_embeddings(self, value):
1339
+ self.model.tok_embeddings = value
1340
+
1341
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
1342
+ def forward(
1343
+ self,
1344
+ input_ids: torch.LongTensor = None,
1345
+ attention_mask: Optional[torch.Tensor] = None,
1346
+ position_ids: Optional[torch.LongTensor] = None,
1347
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1348
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1349
+ labels: Optional[torch.LongTensor] = None,
1350
+ use_cache: Optional[bool] = None,
1351
+ output_attentions: Optional[bool] = None,
1352
+ output_hidden_states: Optional[bool] = None,
1353
+ return_dict: Optional[bool] = None,
1354
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1355
+ r"""
1356
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1357
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1358
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1359
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1360
+ """
1361
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1362
+
1363
+ transformer_outputs = self.model(
1364
+ input_ids,
1365
+ attention_mask=attention_mask,
1366
+ position_ids=position_ids,
1367
+ past_key_values=past_key_values,
1368
+ inputs_embeds=inputs_embeds,
1369
+ use_cache=use_cache,
1370
+ output_attentions=output_attentions,
1371
+ output_hidden_states=output_hidden_states,
1372
+ return_dict=return_dict,
1373
+ )
1374
+ hidden_states = transformer_outputs[0]
1375
+ logits = self.score(hidden_states)
1376
+
1377
+ if input_ids is not None:
1378
+ batch_size = input_ids.shape[0]
1379
+ else:
1380
+ batch_size = inputs_embeds.shape[0]
1381
+
1382
+ if self.config.pad_token_id is None and batch_size != 1:
1383
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
1384
+ if self.config.pad_token_id is None:
1385
+ sequence_lengths = -1
1386
+ else:
1387
+ if input_ids is not None:
1388
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1389
+ logits.device
1390
+ )
1391
+ else:
1392
+ sequence_lengths = -1
1393
+
1394
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1395
+
1396
+ loss = None
1397
+ if labels is not None:
1398
+ labels = labels.to(logits.device)
1399
+ if self.config.problem_type is None:
1400
+ if self.num_labels == 1:
1401
+ self.config.problem_type = 'regression'
1402
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1403
+ self.config.problem_type = 'single_label_classification'
1404
+ else:
1405
+ self.config.problem_type = 'multi_label_classification'
1406
+
1407
+ if self.config.problem_type == 'regression':
1408
+ loss_fct = MSELoss()
1409
+ if self.num_labels == 1:
1410
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1411
+ else:
1412
+ loss = loss_fct(pooled_logits, labels)
1413
+ elif self.config.problem_type == 'single_label_classification':
1414
+ loss_fct = CrossEntropyLoss()
1415
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1416
+ elif self.config.problem_type == 'multi_label_classification':
1417
+ loss_fct = BCEWithLogitsLoss()
1418
+ loss = loss_fct(pooled_logits, labels)
1419
+ if not return_dict:
1420
+ output = (pooled_logits,) + transformer_outputs[1:]
1421
+ return ((loss,) + output) if loss is not None else output
1422
+
1423
+ return SequenceClassifierOutputWithPast(
1424
+ loss=loss,
1425
+ logits=pooled_logits,
1426
+ past_key_values=transformer_outputs.past_key_values,
1427
+ hidden_states=transformer_outputs.hidden_states,
1428
+ attentions=transformer_outputs.attentions,
1429
+ )
modeling_phi3.py ADDED
@@ -0,0 +1,1610 @@
 
1
+ # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """ PyTorch Phi-3 model."""
16
+
17
+ import inspect
18
+ import math
19
+ import warnings
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.modeling_attn_mask_utils import \
30
+ _prepare_4d_causal_attention_mask
31
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
32
+ CausalLMOutputWithPast,
33
+ SequenceClassifierOutputWithPast,
34
+ TokenClassifierOutput)
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ is_flash_attn_2_available,
40
+ is_flash_attn_greater_or_equal_2_10, logging,
41
+ replace_return_docstrings)
42
+
43
+ from .configuration_phi3 import Phi3Config
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ # Transformers scans dependencies in the modeling file, which causes issues with conditional loading. The scan only ignores try/except blocks, not if statements
48
+ # if is_flash_attn_2_available():
49
+ _flash_supports_window_size = False
50
+ try:
51
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
52
+ from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
53
+ unpad_input)
54
+
55
+ _flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters)
56
+ has_flash_attn = True
57
+ except ImportError as error:
58
+ logger.warning(
59
+ f'`flash-attention` package not found, consider installing for better performance: {error}.'
60
+ )
61
+ if not _flash_supports_window_size:
62
+ logger.warning(
63
+ "Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
64
+ )
65
+ has_flash_attn = False
66
+
67
+ _CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct'
68
+ _CONFIG_FOR_DOC = 'Phi3Config'
69
+
70
+ PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
71
+ 'microsoft/Phi-3-mini-4k-instruct',
72
+ 'microsoft/Phi-3-mini-128k-instruct',
73
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
74
+ ]
75
+
76
+
77
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
78
+ class Phi3RMSNorm(nn.Module):
79
+ def __init__(self, hidden_size, eps=1e-6):
80
+ """
81
+ Phi3RMSNorm is equivalent to T5LayerNorm
82
+ """
83
+ super().__init__()
84
+ self.weight = nn.Parameter(torch.ones(hidden_size))
85
+ self.variance_epsilon = eps
86
+
87
+ def forward(self, hidden_states):
88
+ input_dtype = hidden_states.dtype
89
+ hidden_states = hidden_states.to(torch.float32)
90
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
91
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
92
+ return self.weight * hidden_states.to(input_dtype)
93
+
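A minimal reference check for the normalization above, y = weight * x / sqrt(mean(x^2) + eps), computed in float32 and cast back to the input dtype (standalone sketch, not importing the class):

```python
import torch

def rms_norm_ref(x, weight, eps=1e-6):
    x32 = x.float()
    y = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * y.to(x.dtype)

x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
# with weight initialized to ones, this matches Phi3RMSNorm(8)(x)
print(rms_norm_ref(x, torch.ones(8, dtype=torch.bfloat16)).shape)  # torch.Size([2, 4, 8])
```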
94
+
95
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
96
+ def _get_unpad_data(attention_mask):
97
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
98
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
99
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
100
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
101
+ return (
102
+ indices,
103
+ cu_seqlens,
104
+ max_seqlen_in_batch,
105
+ )
106
+
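A worked example of `_get_unpad_data` on a small padded batch: `indices` locate the non-padding tokens in the flattened batch, and `cu_seqlens` are the cumulative sequence lengths that flash-attn's varlen kernels expect.

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)             # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
print(indices.tolist())     # [0, 1, 2, 4, 5]
print(cu_seqlens.tolist())  # [0, 3, 5]
```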
107
+
108
+ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
109
+ class Phi3RotaryEmbedding(nn.Module):
110
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
111
+ super().__init__()
112
+
113
+ self.dim = dim
114
+ self.max_position_embeddings = max_position_embeddings
115
+ self.base = base
116
+ self.register_buffer('inv_freq', None, persistent=False)
117
+
118
+ @torch.no_grad()
119
+ def forward(self, x, position_ids, seq_len=None):
120
+ # x: [bs, num_attention_heads, seq_len, head_size]
121
+ if self.inv_freq is None:
122
+ self.inv_freq = 1.0 / (
123
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
124
+ )
125
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
126
+ position_ids_expanded = position_ids[:, None, :].float()
127
+ # Force float32 since bfloat16 loses precision on long contexts
128
+ # See https://github.com/huggingface/transformers/pull/29285
129
+ device_type = x.device.type
130
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
131
+ with torch.autocast(device_type=device_type, enabled=False):
132
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
133
+ emb = torch.cat((freqs, freqs), dim=-1)
134
+ cos = emb.cos()
135
+ sin = emb.sin()
136
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
137
+
138
+
139
+ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
140
+ def __init__(self, dim, config, device=None):
141
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
142
+
143
+ self.short_factor = config.rope_scaling['short_factor']
144
+ self.long_factor = config.rope_scaling['long_factor']
145
+ self.original_max_position_embeddings = config.original_max_position_embeddings
146
+
147
+ @torch.no_grad()
148
+ def forward(self, x, position_ids, seq_len=None):
149
+ seq_len = torch.max(position_ids) + 1
150
+ if seq_len > self.original_max_position_embeddings:
151
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
152
+ else:
153
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
154
+
155
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
156
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
157
+
158
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
159
+ position_ids_expanded = position_ids[:, None, :].float()
160
+
161
+ # Force float32 since bfloat16 loses precision on long contexts
162
+ # See https://github.com/huggingface/transformers/pull/29285
163
+ device_type = x.device.type
164
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
165
+ with torch.autocast(device_type=device_type, enabled=False):
166
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
167
+ emb = torch.cat((freqs, freqs), dim=-1)
168
+
169
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
170
+ if scale <= 1.0:
171
+ scaling_factor = 1.0
172
+ else:
173
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
174
+
175
+ cos = emb.cos() * scaling_factor
176
+ sin = emb.sin() * scaling_factor
177
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
178
+
179
+
180
+ class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
181
+ def __init__(self, dim, config, device=None):
182
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
183
+
184
+ self.short_factor = config.rope_scaling['short_factor']
185
+ self.long_factor = config.rope_scaling['long_factor']
186
+ self.original_max_position_embeddings = config.original_max_position_embeddings
187
+
188
+ @torch.no_grad()
189
+ def forward(self, x, position_ids, seq_len=None):
190
+ seq_len = torch.max(position_ids) + 1
191
+ if seq_len > self.original_max_position_embeddings:
192
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
193
+ else:
194
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
195
+
196
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
197
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
198
+
199
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
200
+ position_ids_expanded = position_ids[:, None, :].float()
201
+
202
+ # Force float32 since bfloat16 loses precision on long contexts
203
+ # See https://github.com/huggingface/transformers/pull/29285
204
+ device_type = x.device.type
205
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
206
+ with torch.autocast(device_type=device_type, enabled=False):
207
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
208
+ emb = torch.cat((freqs, freqs), dim=-1)
209
+
210
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
211
+ if scale <= 1.0:
212
+ scaling_factor = 1.0
213
+ else:
214
+ scaling_factor = 0.1 * math.log(scale) + 1.0
215
+
216
+ cos = emb.cos() * scaling_factor
217
+ sin = emb.sin() * scaling_factor
218
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
219
+
220
+
221
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
222
+ def rotate_half(x):
223
+ """Rotates half the hidden dims of the input."""
224
+ x1 = x[..., : x.shape[-1] // 2]
225
+ x2 = x[..., x.shape[-1] // 2 :]
226
+ return torch.cat((-x2, x1), dim=-1)
227
+
228
+
229
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
230
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
231
+ """Applies Rotary Position Embedding to the query and key tensors.
232
+
233
+ Args:
234
+ q (`torch.Tensor`): The query tensor.
235
+ k (`torch.Tensor`): The key tensor.
236
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
237
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
238
+ position_ids (`torch.Tensor`, *optional*):
239
+ Deprecated and unused.
240
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
241
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
242
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
243
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
244
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
245
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
246
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
247
+ Returns:
248
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
249
+ """
250
+ cos = cos.unsqueeze(unsqueeze_dim)
251
+ sin = sin.unsqueeze(unsqueeze_dim)
252
+ q_embed = (q * cos) + (rotate_half(q) * sin)
253
+ k_embed = (k * cos) + (rotate_half(k) * sin)
254
+ return q_embed, k_embed
255
+
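A small shape sketch for `apply_rotary_pos_emb` as defined above, using random cos/sin tensors purely to illustrate the broadcasting (real values come from the rotary embedding modules):

```python
import torch

bsz, heads, seq_len, head_dim = 1, 4, 6, 16
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)
cos = torch.randn(bsz, seq_len, head_dim)   # [batch, seq_len, head_dim]
sin = torch.randn(bsz, seq_len, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 6, 16]) torch.Size([1, 4, 6, 16])
```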
256
+
257
+ class Phi3MLP(nn.Module):
258
+ def __init__(self, config):
259
+ super().__init__()
260
+
261
+ self.config = config
262
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
263
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
264
+
265
+ self.activation_fn = ACT2FN[config.hidden_act]
266
+
267
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
268
+ up_states = self.gate_up_proj(hidden_states)
269
+
270
+ gate, up_states = up_states.chunk(2, dim=-1)
271
+ up_states = up_states * self.activation_fn(gate)
272
+
273
+ return self.down_proj(up_states)
274
+
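A standalone sketch of the fused gate/up projection used above: a single Linear produces both halves, which are split with `chunk(2, dim=-1)` and recombined as `up * act(gate)` before the down projection (toy sizes, not the Phi-3 config).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 32
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 5, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)
print(down_proj(up * F.silu(gate)).shape)  # torch.Size([2, 5, 8])
```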
275
+
276
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
277
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
278
+ """
279
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
280
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
281
+ """
282
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
283
+ if n_rep == 1:
284
+ return hidden_states
285
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
286
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
287
+
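A quick check that `repeat_kv` matches `torch.repeat_interleave` along the head dimension, as the docstring states:

```python
import torch

kv = torch.randn(2, 8, 16, 32)   # [batch, num_key_value_heads, seq_len, head_dim]
assert torch.equal(repeat_kv(kv, 4), torch.repeat_interleave(kv, repeats=4, dim=1))
```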
288
+
289
+ class Phi3Attention(nn.Module):
290
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
291
+
292
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
293
+ super().__init__()
294
+ self.config = config
295
+ self.layer_idx = layer_idx
296
+ if layer_idx is None:
297
+ logger.warning_once(
298
+ f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
299
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
300
+ 'when creating this class.'
301
+ )
302
+
303
+ self.attention_dropout = config.attention_dropout
304
+ self.hidden_size = config.hidden_size
305
+ self.num_heads = config.num_attention_heads
306
+ self.head_dim = self.hidden_size // self.num_heads
307
+ self.num_key_value_heads = config.num_key_value_heads
308
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
309
+ self.max_position_embeddings = config.max_position_embeddings
310
+ self.original_max_position_embeddings = config.original_max_position_embeddings
311
+ self.rope_theta = config.rope_theta
312
+ self.rope_scaling = config.rope_scaling
313
+ self.is_causal = True
314
+
315
+ if (self.head_dim * self.num_heads) != self.hidden_size:
316
+ raise ValueError(
317
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
318
+ f' and `num_heads`: {self.num_heads}).'
319
+ )
320
+
321
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
322
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
323
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
324
+ self._init_rope()
325
+
326
+ def _init_rope(self):
327
+ if self.rope_scaling is None:
328
+ self.rotary_emb = Phi3RotaryEmbedding(
329
+ self.head_dim,
330
+ max_position_embeddings=self.max_position_embeddings,
331
+ base=self.rope_theta,
332
+ )
333
+ else:
334
+ scaling_type = self.config.rope_scaling['type']
335
+ if scaling_type == 'su':
336
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
337
+ elif scaling_type == 'yarn':
338
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
339
+ else:
340
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
341
+
342
+ def forward(
343
+ self,
344
+ hidden_states: torch.Tensor,
345
+ attention_mask: Optional[torch.Tensor] = None,
346
+ position_ids: Optional[torch.LongTensor] = None,
347
+ past_key_value: Optional[Cache] = None,
348
+ output_attentions: bool = False,
349
+ use_cache: bool = False,
350
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
351
+ logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.')
352
+
353
+ bsz, q_len, _ = hidden_states.size()
354
+
355
+ qkv = self.qkv_proj(hidden_states)
356
+ query_pos = self.num_heads * self.head_dim
357
+ query_states = qkv[..., :query_pos]
358
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
359
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
360
+
361
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
362
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
363
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
364
+
365
+ kv_seq_len = key_states.shape[-2]
366
+ if past_key_value is not None:
367
+ if self.layer_idx is None:
368
+ raise ValueError(
369
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
370
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
371
+ 'with a layer index.'
372
+ )
373
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
374
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
375
+
376
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
377
+
378
+ if past_key_value is not None:
379
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
380
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
381
+
382
+ # repeat k/v heads if n_kv_heads < n_heads
383
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
384
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
385
+
386
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
387
+
388
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
389
+ raise ValueError(
390
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
391
+ f' {attn_weights.size()}'
392
+ )
393
+
394
+ if attention_mask is not None:
395
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
396
+ raise ValueError(
397
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
398
+ )
399
+ attn_weights = attn_weights + attention_mask
400
+
401
+ # upcast attention to fp32
402
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
403
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
404
+
405
+ attn_output = torch.matmul(attn_weights, value_states)
406
+
407
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
408
+ raise ValueError(
409
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
410
+ f' {attn_output.size()}'
411
+ )
412
+
413
+ attn_output = attn_output.transpose(1, 2).contiguous()
414
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
415
+
416
+ attn_output = self.o_proj(attn_output)
417
+
418
+ if not output_attentions:
419
+ attn_weights = None
420
+
421
+ return attn_output, attn_weights, past_key_value
422
+
423
+
424
+ class Phi3FlashAttention2(Phi3Attention):
425
+ """
426
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
427
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
428
+ flash attention and deal with padding tokens in case the input contains any of them.
429
+ """
430
+
431
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
432
+ def __init__(self, *args, **kwargs):
433
+ super().__init__(*args, **kwargs)
434
+
435
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
436
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
437
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
438
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
439
+
440
+ def forward(
441
+ self,
442
+ hidden_states: torch.Tensor,
443
+ attention_mask: Optional[torch.LongTensor] = None,
444
+ position_ids: Optional[torch.LongTensor] = None,
445
+ past_key_value: Optional[Cache] = None,
446
+ output_attentions: bool = False,
447
+ use_cache: bool = False,
448
+ **kwargs,
449
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
450
+ # Phi3FlashAttention2 attention does not support output_attentions
451
+
452
+ if not _flash_supports_window_size:
453
+ logger.warning_once(
454
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
455
+ )
456
+ raise ValueError('The current flash attention version does not support sliding window attention.')
457
+
458
+ output_attentions = False
459
+
460
+ if 'padding_mask' in kwargs:
461
+ warnings.warn(
462
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
463
+ )
464
+
465
+ # overwrite attention_mask with padding_mask
466
+ attention_mask = kwargs.pop('padding_mask')
467
+
468
+ bsz, q_len, _ = hidden_states.size()
469
+
470
+ qkv = self.qkv_proj(hidden_states)
471
+ query_pos = self.num_heads * self.head_dim
472
+ query_states = qkv[..., :query_pos]
473
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
474
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
475
+
476
+ # Flash attention requires the input to have the shape
477
+ # batch_size x seq_length x num_heads x head_dim
478
+ # therefore we just need to keep the original shape
479
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
480
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
481
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
482
+
483
+ kv_seq_len = key_states.shape[-2]
484
+ if past_key_value is not None:
485
+ if self.layer_idx is None:
486
+ raise ValueError(
487
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
488
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
489
+ 'with a layer index.'
490
+ )
491
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
492
+
493
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
494
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
495
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
496
+
497
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
498
+
499
+ use_sliding_windows = (
500
+ _flash_supports_window_size
501
+ and getattr(self.config, 'sliding_window', None) is not None
502
+ and kv_seq_len > self.config.sliding_window
503
+ )
504
+
505
+ if past_key_value is not None:
506
+ # Activate cache slicing only if the config has a `sliding_window` attribute
507
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
508
+ if (
509
+ getattr(self.config, 'sliding_window', None) is not None
510
+ and kv_seq_len > self.config.sliding_window
511
+ and cache_has_contents
512
+ ):
513
+ slicing_tokens = 1 - self.config.sliding_window
514
+
515
+ past_key = past_key_value[self.layer_idx][0]
516
+ past_value = past_key_value[self.layer_idx][1]
517
+
518
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
519
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
520
+
521
+ if past_key.shape[-2] != self.config.sliding_window - 1:
522
+ raise ValueError(
523
+ f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
524
+ f' {past_key.shape}'
525
+ )
526
+
527
+ if attention_mask is not None:
528
+ attention_mask = attention_mask[:, slicing_tokens:]
529
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
530
+
531
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
532
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
533
+
534
+ # repeat k/v heads if n_kv_heads < n_heads
535
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
536
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
537
+
538
+ attn_dropout = self.attention_dropout if self.training else 0.0
539
+
540
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
541
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
542
+ # cast them back to the correct dtype just to be sure everything works as expected.
543
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
544
+ # in fp32.
545
+
546
+ if query_states.dtype == torch.float32:
547
+ if torch.is_autocast_enabled():
548
+ target_dtype = torch.get_autocast_gpu_dtype()
549
+ # Handle the case where the model is quantized
550
+ elif hasattr(self.config, '_pre_quantization_dtype'):
551
+ target_dtype = self.config._pre_quantization_dtype
552
+ else:
553
+ target_dtype = self.qkv_proj.weight.dtype
554
+
555
+ logger.warning_once(
556
+ f'The input hidden states seem to have been silently cast to float32; this might be related to'
557
+ f' the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to'
558
+ f' {target_dtype}.'
559
+ )
560
+
561
+ query_states = query_states.to(target_dtype)
562
+ key_states = key_states.to(target_dtype)
563
+ value_states = value_states.to(target_dtype)
564
+
565
+ # Reshape to the expected shape for Flash Attention
566
+ query_states = query_states.transpose(1, 2)
567
+ key_states = key_states.transpose(1, 2)
568
+ value_states = value_states.transpose(1, 2)
569
+
570
+ attn_output = self._flash_attention_forward(
571
+ query_states,
572
+ key_states,
573
+ value_states,
574
+ attention_mask,
575
+ q_len,
576
+ dropout=attn_dropout,
577
+ use_sliding_windows=use_sliding_windows,
578
+ )
579
+
580
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
581
+ attn_output = self.o_proj(attn_output)
582
+
583
+ if not output_attentions:
584
+ attn_weights = None
585
+
586
+ return attn_output, attn_weights, past_key_value
587
+
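For readers unfamiliar with the rotary-embedding step applied in the forward pass above, the following is a minimal stand-in sketch of the rotate-half RoPE form that `apply_rotary_pos_emb` implements; the helper name, shapes, and cos/sin layout here are illustrative assumptions, not this file's own definitions.

import torch

def rotate_half(x):
    # split the head dimension in two and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_sketch(q, k, cos, sin):
    # cos/sin: (batch, seq_len, head_dim), broadcast over the head axis
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

q = torch.randn(1, 4, 6, 8)   # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 4, 6, 8)
cos, sin = torch.randn(1, 6, 8), torch.randn(1, 6, 8)
q_rot, k_rot = apply_rope_sketch(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 6, 8]) torch.Size([1, 4, 6, 8])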
588
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
589
+ def _flash_attention_forward(
590
+ self,
591
+ query_states,
592
+ key_states,
593
+ value_states,
594
+ attention_mask,
595
+ query_length,
596
+ dropout=0.0,
597
+ softmax_scale=None,
598
+ use_sliding_windows=False,
599
+ ):
600
+ """
601
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
602
+ the input is first unpadded, the attention scores are computed, and the output is then padded back.
603
+
604
+ Args:
605
+ query_states (`torch.Tensor`):
606
+ Input query states to be passed to Flash Attention API
607
+ key_states (`torch.Tensor`):
608
+ Input key states to be passed to Flash Attention API
609
+ value_states (`torch.Tensor`):
610
+ Input value states to be passed to Flash Attention API
611
+ attention_mask (`torch.Tensor`):
612
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
613
+ position of padding tokens and 1 for the position of non-padding tokens.
614
+ dropout (`float`):
615
+ Attention dropout
616
+ softmax_scale (`float`, *optional*):
617
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
618
+ use_sliding_windows (`bool`, *optional*):
619
+ Whether to activate sliding window attention.
620
+ """
621
+ if not self._flash_attn_uses_top_left_mask:
622
+ causal = self.is_causal
623
+ else:
624
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
625
+ causal = self.is_causal and query_length != 1
626
+
627
+ # Contains at least one padding token in the sequence
628
+ if attention_mask is not None:
629
+ batch_size = query_states.shape[0]
630
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
631
+ query_states, key_states, value_states, attention_mask, query_length
632
+ )
633
+
634
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
635
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
636
+
637
+ if not use_sliding_windows:
638
+ attn_output_unpad = flash_attn_varlen_func(
639
+ query_states,
640
+ key_states,
641
+ value_states,
642
+ cu_seqlens_q=cu_seqlens_q,
643
+ cu_seqlens_k=cu_seqlens_k,
644
+ max_seqlen_q=max_seqlen_in_batch_q,
645
+ max_seqlen_k=max_seqlen_in_batch_k,
646
+ dropout_p=dropout,
647
+ softmax_scale=softmax_scale,
648
+ causal=causal,
649
+ )
650
+ else:
651
+ attn_output_unpad = flash_attn_varlen_func(
652
+ query_states,
653
+ key_states,
654
+ value_states,
655
+ cu_seqlens_q=cu_seqlens_q,
656
+ cu_seqlens_k=cu_seqlens_k,
657
+ max_seqlen_q=max_seqlen_in_batch_q,
658
+ max_seqlen_k=max_seqlen_in_batch_k,
659
+ dropout_p=dropout,
660
+ softmax_scale=softmax_scale,
661
+ causal=causal,
662
+ window_size=(self.config.sliding_window, self.config.sliding_window),
663
+ )
664
+
665
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
666
+ else:
667
+ if not use_sliding_windows:
668
+ attn_output = flash_attn_func(
669
+ query_states,
670
+ key_states,
671
+ value_states,
672
+ dropout,
673
+ softmax_scale=softmax_scale,
674
+ causal=causal,
675
+ )
676
+ else:
677
+ attn_output = flash_attn_func(
678
+ query_states,
679
+ key_states,
680
+ value_states,
681
+ dropout,
682
+ softmax_scale=softmax_scale,
683
+ causal=causal,
684
+ window_size=(self.config.sliding_window, self.config.sliding_window),
685
+ )
686
+
687
+ return attn_output
688
+
689
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
690
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
691
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
692
+
693
+ # On the first iteration we need to properly re-create the padding mask
694
+ # by slicing it on the proper place
695
+ if kv_seq_len != attention_mask.shape[-1]:
696
+ attention_mask_num_tokens = attention_mask.shape[-1]
697
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
698
+
699
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
700
+
701
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
702
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
703
+
704
+ if query_length == kv_seq_len:
705
+ query_layer = index_first_axis(
706
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
707
+ )
708
+ cu_seqlens_q = cu_seqlens_k
709
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
710
+ indices_q = indices_k
711
+ elif query_length == 1:
712
+ max_seqlen_in_batch_q = 1
713
+ cu_seqlens_q = torch.arange(
714
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
715
+ ) # There is a memcpy here, that is very bad.
716
+ indices_q = cu_seqlens_q[:-1]
717
+ query_layer = query_layer.squeeze(1)
718
+ else:
719
+ # The -q_len: slice assumes left padding.
720
+ attention_mask = attention_mask[:, -query_length:]
721
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
722
+
723
+ return (
724
+ query_layer,
725
+ key_layer,
726
+ value_layer,
727
+ indices_q,
728
+ (cu_seqlens_q, cu_seqlens_k),
729
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
730
+ )
731
+
732
+
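As a companion to `_upad_input` above, here is a torch-only sketch of the metadata (flattened token indices, cumulative sequence lengths, and the longest sequence in the batch) consumed by the varlen flash-attention path. The function name below is hypothetical and stands in for the `_get_unpad_data` helper defined elsewhere in this file.

import torch

def get_unpad_metadata(attention_mask: torch.Tensor):
    # attention_mask: (batch_size, seq_len) with 1 = real token, 0 = padding
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)             # tokens per sample
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of real tokens
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )  # prepend 0 -> cumulative offsets expected by flash_attn_varlen_func
    return indices, cu_seqlens, max_seqlen_in_batch

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_len = get_unpad_metadata(mask)
# indices -> tensor([0, 1, 2, 4, 5]); cu_seqlens -> tensor([0, 3, 5]); max_len -> 3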
733
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
734
+ # TODO @Arthur no longer copied from LLama after static cache
735
+ class Phi3SdpaAttention(Phi3Attention):
736
+ """
737
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
738
+ `Phi3Attention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
739
+ SDPA API.
740
+ """
741
+
742
+ # Adapted from Phi3Attention.forward
743
+ def forward(
744
+ self,
745
+ hidden_states: torch.Tensor,
746
+ attention_mask: Optional[torch.Tensor] = None,
747
+ position_ids: Optional[torch.LongTensor] = None,
748
+ past_key_value: Optional[Cache] = None,
749
+ output_attentions: bool = False,
750
+ use_cache: bool = False,
751
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
752
+ if output_attentions:
753
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
754
+ logger.warning_once(
755
+ 'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
756
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
757
+ )
758
+ return super().forward(
759
+ hidden_states=hidden_states,
760
+ attention_mask=attention_mask,
761
+ position_ids=position_ids,
762
+ past_key_value=past_key_value,
763
+ output_attentions=output_attentions,
764
+ use_cache=use_cache,
765
+ )
766
+
767
+ bsz, q_len, _ = hidden_states.size()
768
+
769
+ qkv = self.qkv_proj(hidden_states)
770
+ query_pos = self.num_heads * self.head_dim
771
+ query_states = qkv[..., :query_pos]
772
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
773
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
774
+
775
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
776
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
777
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
778
+
779
+ kv_seq_len = key_states.shape[-2]
780
+ if past_key_value is not None:
781
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
782
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
783
+
784
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
785
+
786
+ if past_key_value is not None:
787
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
788
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
789
+
790
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
791
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
792
+
793
+ if attention_mask is not None:
794
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
795
+ raise ValueError(
796
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
797
+ )
798
+
799
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
800
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
801
+ if query_states.device.type == 'cuda' and attention_mask is not None:
802
+ query_states = query_states.contiguous()
803
+ key_states = key_states.contiguous()
804
+ value_states = value_states.contiguous()
805
+
806
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
807
+ query_states,
808
+ key_states,
809
+ value_states,
810
+ attn_mask=attention_mask,
811
+ dropout_p=self.attention_dropout if self.training else 0.0,
812
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
813
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
814
+ )
815
+
816
+ attn_output = attn_output.transpose(1, 2).contiguous()
817
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
818
+
819
+ attn_output = self.o_proj(attn_output)
820
+
821
+ return attn_output, None, past_key_value
822
+
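A self-contained illustration (torch >= 2.0) of the `scaled_dot_product_attention` call used above: with no explicit mask, the causal flag reproduces causal-LM masking, while a 4D additive mask requires `is_causal=False`, exactly as the branch above encodes. Shapes are arbitrary.

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# With no explicit mask, is_causal=True applies the lower-triangular causal mask;
# when a 4D additive mask is supplied instead, is_causal must be False.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])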
823
+
824
+ PHI3_ATTENTION_CLASSES = {
825
+ 'eager': Phi3Attention,
826
+ 'flash_attention_2': Phi3FlashAttention2,
827
+ 'sdpa': Phi3SdpaAttention,
828
+ }
829
+
830
+
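All three attention classes registered above expand grouped key/value heads before the attention kernel. The sketch below re-implements that expansion for illustration only; the file defines its own `repeat_kv`.

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_attention_heads, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 5, 8)          # 2 KV heads
print(repeat_kv_sketch(kv, 4).shape)  # torch.Size([1, 8, 5, 8]) -> 8 query heads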
831
+ class Phi3DecoderLayer(nn.Module):
832
+ def __init__(self, config: Phi3Config, layer_idx: int):
833
+ super().__init__()
834
+
835
+ self.config = config
836
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
837
+
838
+ self.mlp = Phi3MLP(config)
839
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
840
+
841
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
842
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
843
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
844
+
845
+ def forward(
846
+ self,
847
+ hidden_states: torch.Tensor,
848
+ attention_mask: Optional[torch.Tensor] = None,
849
+ position_ids: Optional[torch.LongTensor] = None,
850
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
851
+ output_attentions: Optional[bool] = False,
852
+ use_cache: Optional[bool] = False,
853
+ **kwargs,
854
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
855
+ if 'padding_mask' in kwargs:
856
+ warnings.warn(
857
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
858
+ )
859
+ """
860
+ Args:
861
+ hidden_states (`torch.FloatTensor`):
862
+ input to the layer of shape `(batch, seq_len, embed_dim)`
863
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
864
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
865
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
866
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
867
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
868
+ output_attentions (`bool`, *optional*):
869
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
870
+ returned tensors for more detail.
871
+ use_cache (`bool`, *optional*):
872
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
873
+ (see `past_key_values`).
874
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
875
+ """
876
+
877
+ residual = hidden_states
878
+
879
+ hidden_states = self.input_layernorm(hidden_states)
880
+
881
+ # Self Attention
882
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
883
+ hidden_states=hidden_states,
884
+ attention_mask=attention_mask,
885
+ position_ids=position_ids,
886
+ past_key_value=past_key_value,
887
+ output_attentions=output_attentions,
888
+ use_cache=use_cache,
889
+ )
890
+
891
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
892
+
893
+ residual = hidden_states
894
+ hidden_states = self.post_attention_layernorm(hidden_states)
895
+ hidden_states = self.mlp(hidden_states)
896
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
897
+
898
+ outputs = (hidden_states,)
899
+
900
+ if output_attentions:
901
+ outputs += (self_attn_weights,)
902
+
903
+ if use_cache:
904
+ outputs += (present_key_value,)
905
+
906
+ return outputs
907
+
908
+
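`Phi3DecoderLayer` follows a pre-norm residual layout: normalize, attend, add the residual, then normalize, run the MLP, and add the residual again. The condensed sketch below captures that control flow only; plain `nn.LayerNorm` and generic linear modules stand in for `Phi3RMSNorm`, the attention block, and `Phi3MLP`.

import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, dim: int, mixer: nn.Module, mlp: nn.Module, pdrop: float = 0.0):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.mixer, self.mlp = mixer, mlp
        self.drop1, self.drop2 = nn.Dropout(pdrop), nn.Dropout(pdrop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop1(self.mixer(self.norm1(x)))   # attention sub-block
        x = x + self.drop2(self.mlp(self.norm2(x)))     # feed-forward sub-block
        return x

blk = PreNormBlockSketch(16, nn.Linear(16, 16), nn.Linear(16, 16), pdrop=0.1)
print(blk(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])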
909
+ PHI3_START_DOCSTRING = r"""
910
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
911
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
912
+ etc.)
913
+
914
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
915
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
916
+ and behavior.
917
+
918
+ Parameters:
919
+ config ([`Phi3Config`]):
920
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
921
+ load the weights associated with the model, only the configuration. Check out the
922
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
923
+ """
924
+
925
+
926
+ @add_start_docstrings(
927
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
928
+ PHI3_START_DOCSTRING,
929
+ )
930
+ class Phi3PreTrainedModel(PreTrainedModel):
931
+ config_class = Phi3Config
932
+ base_model_prefix = 'model'
933
+ supports_gradient_checkpointing = True
934
+ _no_split_modules = ['Phi3DecoderLayer']
935
+ _skip_keys_device_placement = 'past_key_values'
936
+ _supports_flash_attn_2 = True
937
+ _supports_sdpa = False
938
+ _supports_cache_class = True
939
+
940
+ _version = '0.0.5'
941
+
942
+ def __init__(self, config: Phi3Config):
943
+ if not has_flash_attn:
944
+ config._attn_implementation = 'eager'
945
+ print('Warning: Flash attention is not available, using eager attention instead.')
946
+ super().__init__(config)
947
+
948
+ def _init_weights(self, module):
949
+ std = self.config.initializer_range
950
+ if isinstance(module, nn.Linear):
951
+ module.weight.data.normal_(mean=0.0, std=std)
952
+ if module.bias is not None:
953
+ module.bias.data.zero_()
954
+ elif isinstance(module, nn.Embedding):
955
+ module.weight.data.normal_(mean=0.0, std=std)
956
+ if module.padding_idx is not None:
957
+ module.weight.data[module.padding_idx].zero_()
958
+
959
+
960
+ PHI3_INPUTS_DOCSTRING = r"""
961
+ Args:
962
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
963
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
964
+ it.
965
+
966
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
967
+ [`PreTrainedTokenizer.__call__`] for details.
968
+
969
+ [What are input IDs?](../glossary#input-ids)
970
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
971
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
972
+
973
+ - 1 for tokens that are **not masked**,
974
+ - 0 for tokens that are **masked**.
975
+
976
+ [What are attention masks?](../glossary#attention-mask)
977
+
978
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
979
+ [`PreTrainedTokenizer.__call__`] for details.
980
+
981
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
982
+ `past_key_values`).
983
+
984
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
985
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
986
+ information on the default strategy.
987
+
988
+ - 1 indicates the head is **not masked**,
989
+ - 0 indicates the head is **masked**.
990
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
991
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
992
+ config.n_positions - 1]`.
993
+
994
+ [What are position IDs?](../glossary#position-ids)
995
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
996
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
997
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
998
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
999
+
1000
+ Two formats are allowed:
1001
+ - a [`~cache_utils.Cache`] instance;
1002
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1003
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1004
+ cache format.
1005
+
1006
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1007
+ legacy cache format will be returned.
1008
+
1009
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1010
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1011
+ of shape `(batch_size, sequence_length)`.
1012
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1013
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1014
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1015
+ model's internal embedding lookup matrix.
1016
+ use_cache (`bool`, *optional*):
1017
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1018
+ `past_key_values`).
1019
+ output_attentions (`bool`, *optional*):
1020
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1021
+ tensors for more detail.
1022
+ output_hidden_states (`bool`, *optional*):
1023
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1024
+ more detail.
1025
+ return_dict (`bool`, *optional*):
1026
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1027
+ """
1028
+
1029
+
1030
+ @add_start_docstrings(
1031
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
1032
+ PHI3_START_DOCSTRING,
1033
+ )
1034
+ class Phi3Model(Phi3PreTrainedModel):
1035
+ """
1036
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
1037
+
1038
+ Args:
1039
+ config: Phi3Config
1040
+ """
1041
+
1042
+ def __init__(self, config: Phi3Config):
1043
+ super().__init__(config)
1044
+ self.padding_idx = config.pad_token_id
1045
+ self.vocab_size = config.vocab_size
1046
+
1047
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1048
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1049
+ self.layers = nn.ModuleList(
1050
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1051
+ )
1052
+ self._attn_implementation = config._attn_implementation
1053
+
1054
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1055
+
1056
+ self.gradient_checkpointing = False
1057
+ # Initialize weights and apply final processing
1058
+ self.post_init()
1059
+
1060
+ def get_input_embeddings(self):
1061
+ return self.embed_tokens
1062
+
1063
+ def set_input_embeddings(self, value):
1064
+ self.embed_tokens = value
1065
+
1066
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1067
+ def forward(
1068
+ self,
1069
+ input_ids: torch.LongTensor = None,
1070
+ attention_mask: Optional[torch.Tensor] = None,
1071
+ position_ids: Optional[torch.LongTensor] = None,
1072
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1073
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1074
+ use_cache: Optional[bool] = None,
1075
+ output_attentions: Optional[bool] = None,
1076
+ output_hidden_states: Optional[bool] = None,
1077
+ return_dict: Optional[bool] = None,
1078
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1079
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1080
+ output_hidden_states = (
1081
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1082
+ )
1083
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1084
+
1085
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1086
+
1087
+ # retrieve input_ids and inputs_embeds
1088
+ if input_ids is not None and inputs_embeds is not None:
1089
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
1090
+ elif input_ids is not None:
1091
+ batch_size, seq_length = input_ids.shape[:2]
1092
+ elif inputs_embeds is not None:
1093
+ batch_size, seq_length = inputs_embeds.shape[:2]
1094
+ else:
1095
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
1096
+
1097
+ past_key_values_length = 0
1098
+
1099
+ if self.gradient_checkpointing and self.training:
1100
+ if use_cache:
1101
+ logger.warning_once(
1102
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
1103
+ )
1104
+ use_cache = False
1105
+
1106
+ if use_cache:
1107
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1108
+ if use_legacy_cache:
1109
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1110
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1111
+
1112
+ if position_ids is None:
1113
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1114
+ position_ids = torch.arange(
1115
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1116
+ )
1117
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1118
+ else:
1119
+ position_ids = position_ids.view(-1, seq_length).long()
1120
+
1121
+ if inputs_embeds is None:
1122
+ inputs_embeds = self.embed_tokens(input_ids)
1123
+
1124
+ if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache:
1125
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1126
+ if is_padding_right:
1127
+ raise ValueError(
1128
+ "You are attempting to perform batched generation with padding_side='right'"
1129
+ ' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to '
1130
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1131
+ )
1132
+
1133
+ if self._attn_implementation == 'flash_attention_2':
1134
+ # 2d mask is passed through the layers
1135
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1136
+ else:
1137
+ # 4d mask is passed through the layers
1138
+ attention_mask = _prepare_4d_causal_attention_mask(
1139
+ attention_mask,
1140
+ (batch_size, seq_length),
1141
+ inputs_embeds,
1142
+ past_key_values_length,
1143
+ sliding_window=self.config.sliding_window,
1144
+ )
1145
+
1146
+ hidden_states = inputs_embeds
1147
+
1148
+ # decoder layers
1149
+ all_hidden_states = () if output_hidden_states else None
1150
+ all_self_attns = () if output_attentions else None
1151
+ next_decoder_cache = None
1152
+
1153
+ for decoder_layer in self.layers:
1154
+ if output_hidden_states:
1155
+ all_hidden_states += (hidden_states,)
1156
+
1157
+ if self.gradient_checkpointing and self.training:
1158
+ layer_outputs = self._gradient_checkpointing_func(
1159
+ decoder_layer.__call__,
1160
+ hidden_states,
1161
+ attention_mask,
1162
+ position_ids,
1163
+ past_key_values,
1164
+ output_attentions,
1165
+ use_cache,
1166
+ )
1167
+ else:
1168
+ layer_outputs = decoder_layer(
1169
+ hidden_states,
1170
+ attention_mask=attention_mask,
1171
+ position_ids=position_ids,
1172
+ past_key_value=past_key_values,
1173
+ output_attentions=output_attentions,
1174
+ use_cache=use_cache,
1175
+ )
1176
+
1177
+ hidden_states = layer_outputs[0]
1178
+
1179
+ if use_cache:
1180
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1181
+
1182
+ if output_attentions:
1183
+ all_self_attns += (layer_outputs[1],)
1184
+
1185
+ hidden_states = self.norm(hidden_states)
1186
+
1187
+ # add hidden states from the last decoder layer
1188
+ if output_hidden_states:
1189
+ all_hidden_states += (hidden_states,)
1190
+
1191
+ next_cache = None
1192
+ if use_cache:
1193
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1194
+ if not return_dict:
1195
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1196
+ return BaseModelOutputWithPast(
1197
+ last_hidden_state=hidden_states,
1198
+ past_key_values=next_cache,
1199
+ hidden_states=all_hidden_states,
1200
+ attentions=all_self_attns,
1201
+ )
1202
+
1203
+
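`Phi3Model.forward` transparently converts legacy tuple caches into a `DynamicCache` and back when returning. The snippet below shows that round trip in isolation; tensor sizes are arbitrary, and it assumes a transformers version that exposes `DynamicCache` (consistent with the cache utilities this file already imports).

import torch
from transformers import DynamicCache

legacy = tuple(
    (torch.zeros(1, 4, 3, 8), torch.zeros(1, 4, 3, 8))  # (key, value) per layer
    for _ in range(2)
)
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_usable_length(5), len(cache.to_legacy_cache()))  # 3 2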
1204
+ class Phi3ForCausalLM(Phi3PreTrainedModel):
1205
+ _tied_weights_keys = ['lm_head.weight']
1206
+
1207
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
1208
+ def __init__(self, config):
1209
+ super().__init__(config)
1210
+ self.model = Phi3Model(config)
1211
+ self.vocab_size = config.vocab_size
1212
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1213
+
1214
+ # Initialize weights and apply final processing
1215
+ self.post_init()
1216
+
1217
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1218
+ def get_input_embeddings(self):
1219
+ return self.model.embed_tokens
1220
+
1221
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1222
+ def set_input_embeddings(self, value):
1223
+ self.model.embed_tokens = value
1224
+
1225
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1226
+ def get_output_embeddings(self):
1227
+ return self.lm_head
1228
+
1229
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1230
+ def set_output_embeddings(self, new_embeddings):
1231
+ self.lm_head = new_embeddings
1232
+
1233
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1234
+ def set_decoder(self, decoder):
1235
+ self.model = decoder
1236
+
1237
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1238
+ def get_decoder(self):
1239
+ return self.model
1240
+
1241
+ # Ignore copy
1242
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1243
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1244
+ def forward(
1245
+ self,
1246
+ input_ids: torch.LongTensor = None,
1247
+ attention_mask: Optional[torch.Tensor] = None,
1248
+ position_ids: Optional[torch.LongTensor] = None,
1249
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1250
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1251
+ labels: Optional[torch.LongTensor] = None,
1252
+ use_cache: Optional[bool] = None,
1253
+ output_attentions: Optional[bool] = None,
1254
+ output_hidden_states: Optional[bool] = None,
1255
+ return_dict: Optional[bool] = None,
1256
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1257
+ r"""
1258
+ Args:
1259
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1260
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1261
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1262
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1263
+
1264
+ Returns:
1265
+
1266
+ Example:
1267
+
1268
+ ```python
1269
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
1270
+
1271
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1272
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
1273
+
1274
+ >>> prompt = "This is an example script ."
1275
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1276
+
1277
+ >>> # Generate
1278
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1279
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1280
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
1281
+ ```"""
1282
+
1283
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1284
+ output_hidden_states = (
1285
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1286
+ )
1287
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1288
+
1289
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1290
+ outputs = self.model(
1291
+ input_ids=input_ids,
1292
+ attention_mask=attention_mask,
1293
+ position_ids=position_ids,
1294
+ past_key_values=past_key_values,
1295
+ inputs_embeds=inputs_embeds,
1296
+ use_cache=use_cache,
1297
+ output_attentions=output_attentions,
1298
+ output_hidden_states=output_hidden_states,
1299
+ return_dict=return_dict,
1300
+ )
1301
+
1302
+ hidden_states = outputs[0]
1303
+ logits = self.lm_head(hidden_states)
1304
+ logits = logits.float()
1305
+
1306
+ loss = None
1307
+ if labels is not None:
1308
+ # Shift so that tokens < n predict n
1309
+ shift_logits = logits[..., :-1, :].contiguous()
1310
+ shift_labels = labels[..., 1:].contiguous()
1311
+ # Flatten the tokens
1312
+ loss_fct = CrossEntropyLoss()
1313
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1314
+ shift_labels = shift_labels.view(-1)
1315
+ # Enable model parallelism
1316
+ shift_labels = shift_labels.to(shift_logits.device)
1317
+ loss = loss_fct(shift_logits, shift_labels)
1318
+
1319
+ if not return_dict:
1320
+ output = (logits,) + outputs[1:]
1321
+ return (loss,) + output if loss is not None else output
1322
+
1323
+ return CausalLMOutputWithPast(
1324
+ loss=loss,
1325
+ logits=logits,
1326
+ past_key_values=outputs.past_key_values,
1327
+ hidden_states=outputs.hidden_states,
1328
+ attentions=outputs.attentions,
1329
+ )
1330
+
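The loss computation above implements the standard causal-LM one-token shift: position t's logits are scored against token t+1. A tiny numeric illustration with made-up values:

import torch

labels = torch.tensor([[10, 11, 12, 13]])
logits = torch.randn(1, 4, 50)                 # (batch, seq_len, vocab)
shift_logits = logits[..., :-1, :]             # positions 0..2
shift_labels = labels[..., 1:]                 # targets 11, 12, 13
loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, 50), shift_labels.reshape(-1)
)
print(shift_logits.shape, shift_labels.shape, loss.shape)
# torch.Size([1, 3, 50]) torch.Size([1, 3]) torch.Size([])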
1331
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
1332
+ def prepare_inputs_for_generation(
1333
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1334
+ ):
1335
+ if past_key_values is not None:
1336
+ if isinstance(past_key_values, Cache):
1337
+ cache_length = past_key_values.get_seq_length()
1338
+ past_length = past_key_values.seen_tokens
1339
+ max_cache_length = past_key_values.get_max_length()
1340
+ else:
1341
+ cache_length = past_length = past_key_values[0][0].shape[2]
1342
+ max_cache_length = None
1343
+
1344
+ # Keep only the unprocessed tokens:
1345
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1346
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1347
+ # input)
1348
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1349
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1350
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1351
+ # input_ids based on the past_length.
1352
+ elif past_length < input_ids.shape[1]:
1353
+ input_ids = input_ids[:, past_length:]
1354
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1355
+
1356
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1357
+ if (
1358
+ max_cache_length is not None
1359
+ and attention_mask is not None
1360
+ and cache_length + input_ids.shape[1] > max_cache_length
1361
+ ):
1362
+ attention_mask = attention_mask[:, -max_cache_length:]
1363
+
1364
+ position_ids = kwargs.get('position_ids', None)
1365
+ if attention_mask is not None and position_ids is None:
1366
+ # create position_ids on the fly for batch generation
1367
+ position_ids = attention_mask.long().cumsum(-1) - 1
1368
+ position_ids.masked_fill_(attention_mask == 0, 1)
1369
+ if past_key_values:
1370
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1371
+
1372
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1373
+ if (inputs_embeds is not None and past_key_values is None) or (inputs_embeds is not None and len(past_key_values) == 0):
1374
+ model_inputs = {'inputs_embeds': inputs_embeds}
1375
+ else:
1376
+ model_inputs = {'input_ids': input_ids}
1377
+
1378
+ model_inputs.update(
1379
+ {
1380
+ 'position_ids': position_ids,
1381
+ 'past_key_values': past_key_values,
1382
+ 'use_cache': kwargs.get('use_cache'),
1383
+ 'attention_mask': attention_mask,
1384
+ }
1385
+ )
1386
+ return model_inputs
1387
+
1388
+ @staticmethod
1389
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1390
+ def _reorder_cache(past_key_values, beam_idx):
1391
+ reordered_past = ()
1392
+ for layer_past in past_key_values:
1393
+ reordered_past += (
1394
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1395
+ )
1396
+ return reordered_past
1397
+
1398
+
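`_reorder_cache` simply re-indexes the batch dimension of every cached tensor so that each beam keeps following the hypothesis it was assigned. A standalone illustration of that `index_select`:

import torch

past_state = torch.arange(6).view(3, 2)      # per-layer cache slice, batch size 3
beam_idx = torch.tensor([2, 0, 0])           # beams 0/1/2 now continue samples 2/0/0
print(past_state.index_select(0, beam_idx))
# tensor([[4, 5],
#         [0, 1],
#         [0, 1]])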
1399
+ @add_start_docstrings(
1400
+ """
1401
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
1402
+
1403
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1404
+ (e.g. GPT-2) do.
1405
+
1406
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1407
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1408
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1409
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1410
+ each row of the batch).
1411
+ """,
1412
+ PHI3_START_DOCSTRING,
1413
+ )
1414
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
1415
+ class Phi3ForSequenceClassification(Phi3PreTrainedModel):
1416
+ def __init__(self, config):
1417
+ super().__init__(config)
1418
+ self.num_labels = config.num_labels
1419
+ self.model = Phi3Model(config)
1420
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1421
+
1422
+ # Initialize weights and apply final processing
1423
+ self.post_init()
1424
+
1425
+ def get_input_embeddings(self):
1426
+ return self.model.embed_tokens
1427
+
1428
+ def set_input_embeddings(self, value):
1429
+ self.model.embed_tokens = value
1430
+
1431
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1432
+ def forward(
1433
+ self,
1434
+ input_ids: torch.LongTensor = None,
1435
+ attention_mask: Optional[torch.Tensor] = None,
1436
+ position_ids: Optional[torch.LongTensor] = None,
1437
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1438
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1439
+ labels: Optional[torch.LongTensor] = None,
1440
+ use_cache: Optional[bool] = None,
1441
+ output_attentions: Optional[bool] = None,
1442
+ output_hidden_states: Optional[bool] = None,
1443
+ return_dict: Optional[bool] = None,
1444
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1445
+ r"""
1446
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1447
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1448
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1449
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1450
+ """
1451
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1452
+
1453
+ model_outputs = self.model(
1454
+ input_ids,
1455
+ attention_mask=attention_mask,
1456
+ position_ids=position_ids,
1457
+ past_key_values=past_key_values,
1458
+ inputs_embeds=inputs_embeds,
1459
+ use_cache=use_cache,
1460
+ output_attentions=output_attentions,
1461
+ output_hidden_states=output_hidden_states,
1462
+ return_dict=return_dict,
1463
+ )
1464
+ hidden_states = model_outputs[0]
1465
+ logits = self.score(hidden_states)
1466
+
1467
+ if input_ids is not None:
1468
+ batch_size = input_ids.shape[0]
1469
+ else:
1470
+ batch_size = inputs_embeds.shape[0]
1471
+
1472
+ if self.config.pad_token_id is None and batch_size != 1:
1473
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
1474
+ if self.config.pad_token_id is None:
1475
+ sequence_lengths = -1
1476
+ else:
1477
+ if input_ids is not None:
1478
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1479
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1480
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1481
+ sequence_lengths = sequence_lengths.to(logits.device)
1482
+ else:
1483
+ sequence_lengths = -1
1484
+
1485
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1486
+
1487
+ loss = None
1488
+ if labels is not None:
1489
+ labels = labels.to(logits.device)
1490
+ if self.config.problem_type is None:
1491
+ if self.num_labels == 1:
1492
+ self.config.problem_type = 'regression'
1493
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1494
+ self.config.problem_type = 'single_label_classification'
1495
+ else:
1496
+ self.config.problem_type = 'multi_label_classification'
1497
+
1498
+ if self.config.problem_type == 'regression':
1499
+ loss_fct = MSELoss()
1500
+ if self.num_labels == 1:
1501
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1502
+ else:
1503
+ loss = loss_fct(pooled_logits, labels)
1504
+ elif self.config.problem_type == 'single_label_classification':
1505
+ loss_fct = CrossEntropyLoss()
1506
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1507
+ elif self.config.problem_type == 'multi_label_classification':
1508
+ loss_fct = BCEWithLogitsLoss()
1509
+ loss = loss_fct(pooled_logits, labels)
1510
+ if not return_dict:
1511
+ output = (pooled_logits,) + model_outputs[1:]
1512
+ return ((loss,) + output) if loss is not None else output
1513
+
1514
+ return SequenceClassifierOutputWithPast(
1515
+ loss=loss,
1516
+ logits=pooled_logits,
1517
+ past_key_values=model_outputs.past_key_values,
1518
+ hidden_states=model_outputs.hidden_states,
1519
+ attentions=model_outputs.attentions,
1520
+ )
1521
+
1522
+
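The sequence-classification head above pools the logits of the last non-padding token per row, using an argmax-over-pad-mask trick that stays ONNX-friendly. The following torch-only sketch reproduces that indexing with made-up values:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])          # right-padded batch
logits = torch.randn(2, 4, 3)                                   # (batch, seq_len, num_labels)

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]        # handles rows with no padding
pooled = logits[torch.arange(2), sequence_lengths]               # last real token per row
print(sequence_lengths.tolist(), pooled.shape)                   # [2, 1] torch.Size([2, 3])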
1523
+ @add_start_docstrings(
1524
+ """
1525
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1526
+ Named-Entity-Recognition (NER) tasks.
1527
+ """,
1528
+ PHI3_START_DOCSTRING,
1529
+ )
1530
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
1531
+ class Phi3ForTokenClassification(Phi3PreTrainedModel):
1532
+ def __init__(self, config: Phi3Config):
1533
+ super().__init__(config)
1534
+ self.num_labels = config.num_labels
1535
+
1536
+ self.model = Phi3Model(config)
1537
+ if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None:
1538
+ classifier_dropout = config.classifier_dropout
1539
+ elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None:
1540
+ classifier_dropout = config.hidden_dropout
1541
+ else:
1542
+ classifier_dropout = 0.1
1543
+ self.dropout = nn.Dropout(classifier_dropout)
1544
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1545
+
1546
+ # Initialize weights and apply final processing
1547
+ self.post_init()
1548
+
1549
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
1550
+ @add_code_sample_docstrings(
1551
+ checkpoint=_CHECKPOINT_FOR_DOC,
1552
+ output_type=TokenClassifierOutput,
1553
+ config_class=_CONFIG_FOR_DOC,
1554
+ )
1555
+ def forward(
1556
+ self,
1557
+ input_ids: Optional[torch.LongTensor] = None,
1558
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1559
+ attention_mask: Optional[torch.Tensor] = None,
1560
+ inputs_embeds: Optional[torch.Tensor] = None,
1561
+ labels: Optional[torch.Tensor] = None,
1562
+ use_cache: Optional[bool] = None,
1563
+ output_attentions: Optional[bool] = None,
1564
+ output_hidden_states: Optional[bool] = None,
1565
+ return_dict: Optional[bool] = None,
1566
+ **deprecated_arguments,
1567
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1568
+ r"""
1569
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1570
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1571
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1572
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1573
+ """
1574
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1575
+
1576
+ model_outputs = self.model(
1577
+ input_ids,
1578
+ past_key_values=past_key_values,
1579
+ attention_mask=attention_mask,
1580
+ inputs_embeds=inputs_embeds,
1581
+ use_cache=use_cache,
1582
+ output_attentions=output_attentions,
1583
+ output_hidden_states=output_hidden_states,
1584
+ return_dict=return_dict,
1585
+ )
1586
+
1587
+ hidden_states = model_outputs[0]
1588
+ hidden_states = self.dropout(hidden_states)
1589
+ logits = self.classifier(hidden_states)
1590
+
1591
+ loss = None
1592
+ if labels is not None:
1593
+ # move labels to correct device to enable model parallelism
1594
+ labels = labels.to(logits.device)
1595
+ batch_size, seq_length = labels.shape
1596
+ loss_fct = CrossEntropyLoss()
1597
+ loss = loss_fct(
1598
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1599
+ )
1600
+
1601
+ if not return_dict:
1602
+ output = (logits,) + model_outputs[2:]
1603
+ return ((loss,) + output) if loss is not None else output
1604
+
1605
+ return TokenClassifierOutput(
1606
+ loss=loss,
1607
+ logits=logits,
1608
+ hidden_states=model_outputs.hidden_states,
1609
+ attentions=model_outputs.attentions,
1610
+ )
modeling_sa2va_chat.py ADDED
@@ -0,0 +1,1100 @@
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2024 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import warnings
8
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set
9
+ from PIL import Image
10
+ import re
11
+
12
+ import torchvision.transforms as T
13
+ from torchvision.transforms.functional import InterpolationMode
14
+
15
+ import torch.utils.checkpoint
16
+ import transformers
17
+
18
+ from .modeling_internlm2 import InternLM2ForCausalLM
19
+ from .modeling_phi3 import Phi3ForCausalLM
20
+ from peft import LoraConfig, get_peft_model
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
24
+ LlamaTokenizer, Qwen2ForCausalLM)
25
+ from transformers.modeling_outputs import CausalLMOutputWithPast
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import ModelOutput, logging, TensorType
28
+ from transformers import StoppingCriteriaList, StoppingCriteria
29
+ from transformers.models.mask2former.image_processing_mask2former import (
30
+ remove_low_and_no_objects, check_segment_validity)
31
+
32
+ from .configuration_sa2va_chat import Sa2VAChatConfig
33
+ from .modeling_intern_vit import InternVisionModel, has_flash_attn
34
+
35
+ from .templates import PROMPT_TEMPLATE
36
+
37
+ import numpy as np
38
+ from torchvision.transforms.functional import resize, to_pil_image
39
+
40
+ from types import MethodType
41
+ import torch.nn.functional as F
42
+
43
+ from transformers import Mask2FormerForUniversalSegmentation
44
+
45
+ from .mask2former import (
46
+ Mask2FormerMaskedAttentionDecoder_forward_first3layers,
47
+ Mask2FormerMaskedAttentionDecoder_forward_last3layers,
48
+ Mask2FormerTransformerModule_forward_first_part,
49
+ Mask2FormerTransformerModule_forward_second_part,
50
+ Mask2FormerModel_forward_first_part,
51
+ Mask2FormerModel_forward_second_part,
52
+ Mask2FormerForUniversalSegmentation_forward_first_part,
53
+ Mask2FormerForUniversalSegmentation_forward_second_part,
54
+ _post_init,
55
+ ov_class_predictor,
56
+ Mask2FormerLoss_loss_labels,
57
+ Mask2FormerLoss_loss_masks,
58
+ Mask2FormerLoss_sample_points_using_uncertainty,
59
+ Mask2FormerHungarianMatcher_forward,
60
+ )
61
+
62
+ from .constants import (
63
+ IMG_CONTEXT_TOKEN, OBJ_CONTEXT_TOKEN, SEG_TOKEN, CLS_TOKEN, BG_CLS_TOKEN, OBJ_START_TOKEN, OBJ_END_TOKEN)
64
+
65
+
66
+
67
+ try:
68
+ from .flash_attention import FlashAttention
69
+ has_flash_attn = True
70
+ except ImportError:
71
+ print('FlashAttention is not installed.')
72
+ has_flash_attn = False
73
+
74
+ logger = logging.get_logger(__name__)
75
+
76
+ def version_cmp(v1, v2, op='eq'):
77
+ import operator
78
+
79
+ from packaging import version
80
+ op_func = getattr(operator, op)
81
+ return op_func(version.parse(v1), version.parse(v2))
82
+
83
+ class StopWordStoppingCriteria(StoppingCriteria):
84
+ """StopWord stopping criteria."""
85
+
86
+ def __init__(self, tokenizer, stop_word):
87
+ self.tokenizer = tokenizer
88
+ self.stop_word = stop_word
89
+ self.length = len(self.stop_word)
90
+
91
+ def __call__(self, input_ids, *args, **kwargs) -> bool:
92
+ cur_text = self.tokenizer.decode(input_ids[0])
93
+ cur_text = cur_text.replace('\r', '').replace('\n', '')
94
+ return cur_text[-self.length:] == self.stop_word
95
+
96
+ def get_stop_criteria(
97
+ tokenizer,
98
+ stop_words=[],
99
+ ):
100
+ stop_criteria = StoppingCriteriaList()
101
+ for word in stop_words:
102
+ stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
103
+ return stop_criteria
104
+
105
+ class DirectResize:
106
+ def __init__(self, target_length: int) -> None:
107
+ self.target_length = target_length
108
+
109
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
110
+ """
111
+ Expects a numpy array with shape HxWxC in uint8 format.
112
+ """
113
+ img = to_pil_image(image, mode='RGB')
114
+ return np.array(img.resize((self.target_length, self.target_length)))
115
+
116
+ class Sa2VAChatModel(PreTrainedModel):
117
+ config_class = Sa2VAChatConfig
118
+ main_input_name = 'pixel_values'
119
+ base_model_prefix = 'language_model'
120
+ _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
121
+ 'Phi3DecoderLayer', 'Qwen2DecoderLayer', 'Mask2FormerForUniversalSegmentation']
122
+ _supports_flash_attn_2 = True
123
+ supports_gradient_checkpointing = True
124
+
125
+ def __init__(self, config: Sa2VAChatConfig, vision_model=None, language_model=None, mask2former=None, use_flash_attn=True):
126
+ super().__init__(config)
127
+
128
+ assert version_cmp(transformers.__version__, '4.37.0', 'ge')
129
+ image_size = config.force_image_size or config.vision_config.image_size
130
+ patch_size = config.vision_config.patch_size
131
+ self.patch_size = patch_size
132
+ self.select_layer = config.select_layer
133
+ self.template = config.template
134
+ self.template = self.template.replace('-', '_')
135
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
136
+ self.downsample_ratio = config.downsample_ratio
137
+ self.ps_version = config.ps_version
138
+ self.llm_arch_name = config.llm_config.architectures[0]
139
+
140
+ use_flash_attn = use_flash_attn if has_flash_attn else False
141
+ config.vision_config.use_flash_attn = True if use_flash_attn else False
142
+ config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'
143
+
144
+ logger.info(f'num_image_token: {self.num_image_token}')
145
+ logger.info(f'ps_version: {self.ps_version}')
146
+ if vision_model is not None:
147
+ self.vision_model = vision_model
148
+ else:
149
+ self.vision_model = InternVisionModel(config.vision_config)
150
+ if language_model is not None:
151
+ self.language_model = language_model
152
+ else:
153
+ if config.llm_config.architectures[0] == 'LlamaForCausalLM':
154
+ self.language_model = LlamaForCausalLM(config.llm_config)
155
+ elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
156
+ self.language_model = InternLM2ForCausalLM(config.llm_config)
157
+ elif config.llm_config.architectures[0] == 'Phi3ForCausalLM':
158
+ self.language_model = Phi3ForCausalLM(config.llm_config)
159
+ elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
160
+ self.language_model = Qwen2ForCausalLM(config.llm_config)
161
+ else:
162
+ raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
163
+
164
+ vit_hidden_size = config.vision_config.hidden_size
165
+ llm_hidden_size = config.llm_config.hidden_size
166
+
167
+ self.mlp1 = nn.Sequential(
168
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
169
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
170
+ nn.GELU(),
171
+ nn.Linear(llm_hidden_size, llm_hidden_size)
172
+ )
173
+
174
+ self.img_context_token_id = None
175
+ self.conv_template = PROMPT_TEMPLATE[self.template]
176
+ self.template = self.conv_template
177
+ if hasattr(config, 'system_message'):
178
+ self.system_message = config.system_message
179
+ self.num_samples = 0
180
+
181
+ if config.use_backbone_lora:
182
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
183
+
184
+ if config.use_llm_lora:
185
+ self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
186
+
187
+ # mask2former
188
+ if mask2former is None:
189
+ self.mask2former = Mask2FormerForUniversalSegmentation(config.m2f_config)
190
+ else:
191
+ self.mask2former = mask2former
192
+ assert self.mask2former.config.num_queries == config.num_m2f_queries
193
+ self.num_m2f_queries = config.num_m2f_queries
194
+ self.num_m2f_proposals = config.num_m2f_proposals
195
+ self.m2f_input_size = 1024
196
+
197
+ # register functions
198
+ self.mask2former._post_init = MethodType(_post_init, self.mask2former)
199
+ self.mask2former.ov_class_predictor = MethodType(ov_class_predictor, self.mask2former)
200
+ self.mask2former.criterion.loss_labels = MethodType(Mask2FormerLoss_loss_labels, self.mask2former.criterion)
201
+ self.mask2former.criterion.loss_masks = MethodType(Mask2FormerLoss_loss_masks, self.mask2former.criterion)
202
+ self.mask2former.criterion.sample_points_using_uncertainty = MethodType(
203
+ Mask2FormerLoss_sample_points_using_uncertainty, self.mask2former.criterion)
204
+ self.mask2former.forward_first_part = MethodType(Mask2FormerForUniversalSegmentation_forward_first_part, self.mask2former)
205
+ self.mask2former.forward_second_part = MethodType(Mask2FormerForUniversalSegmentation_forward_second_part, self.mask2former)
206
+ self.mask2former.model.Mask2FormerModel_forward_first_part = MethodType(
207
+ Mask2FormerModel_forward_first_part, self.mask2former.model)
208
+ self.mask2former.model.Mask2FormerModel_forward_second_part = MethodType(
209
+ Mask2FormerModel_forward_second_part, self.mask2former.model)
210
+ self.mask2former.model.transformer_module.Mask2FormerTransformerModule_forward_first_part = MethodType(
211
+ Mask2FormerTransformerModule_forward_first_part, self.mask2former.model.transformer_module
212
+ )
213
+ self.mask2former.model.transformer_module.Mask2FormerTransformerModule_forward_second_part = MethodType(
214
+ Mask2FormerTransformerModule_forward_second_part, self.mask2former.model.transformer_module
215
+ )
216
+ self.mask2former.model.transformer_module.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers = MethodType(
217
+ Mask2FormerMaskedAttentionDecoder_forward_first3layers, self.mask2former.model.transformer_module.decoder
218
+ )
219
+ self.mask2former.model.transformer_module.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers = MethodType(
220
+ Mask2FormerMaskedAttentionDecoder_forward_last3layers, self.mask2former.model.transformer_module.decoder
221
+ )
222
+ self.mask2former.criterion.matcher.forward = MethodType(Mask2FormerHungarianMatcher_forward, self.mask2former.criterion.matcher)
223
+
224
+ # post_init of mask2former
225
+ self.mask2former._post_init()
226
+
227
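Aside (illustration, not part of the uploaded file): the MethodType assignments above graft standalone functions onto specific Mask2Former instances so they behave like bound methods. A tiny self-contained sketch of the pattern with a made-up class:

from types import MethodType

class Toy:
    def __init__(self):
        self.value = 1

def double(self):
    # receives the patched instance as `self`, like a regular method
    return self.value * 2

toy = Toy()
toy.double = MethodType(double, toy)  # patches this instance only, not the class
print(toy.double())  # 2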
+ out_dim = config.m2f_config.hidden_dim
228
+ in_dim = config.llm_config.hidden_size
229
+
230
+ self.m2f_to_llm = nn.Sequential(
231
+ nn.LayerNorm(out_dim),
232
+ nn.Linear(out_dim, in_dim),
233
+ nn.GELU(),
234
+ nn.Linear(in_dim, in_dim)
235
+ )
236
+
237
+ self.llm_to_m2f = nn.Sequential(
238
+ nn.LayerNorm(in_dim),
239
+ nn.Linear(in_dim, out_dim * 2),
240
+ nn.GELU(),
241
+ nn.Linear(out_dim * 2, out_dim * 2)
242
+ )
243
+
244
+ self.llm_to_cls = nn.Sequential(
245
+ nn.LayerNorm(in_dim),
246
+ nn.Linear(in_dim, out_dim),
247
+ nn.GELU(),
248
+ nn.Linear(out_dim, out_dim)
249
+ )
250
+
251
+ self.init_prediction_config = False
252
+
253
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
254
+ lora_config = LoraConfig(
255
+ r=r,
256
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
257
+ lora_alpha=lora_alpha,
258
+ lora_dropout=lora_dropout,
259
+ )
260
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
261
+ self.vision_model.print_trainable_parameters()
262
+
263
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
264
+ # Determine the target modules based on the architecture of the language model
265
+ if self.llm_arch_name == 'InternLM2ForCausalLM':
266
+ target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
267
+ elif self.llm_arch_name == 'Phi3ForCausalLM':
268
+ target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
269
+ elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
270
+ target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
271
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
272
+ else:
273
+ raise NotImplementedError(f'{self.llm_arch_name} is not implemented.')
274
+ lora_config = LoraConfig(
275
+ r=r,
276
+ target_modules=target_modules,
277
+ lora_alpha=lora_alpha,
278
+ lora_dropout=lora_dropout,
279
+ task_type='CAUSAL_LM'
280
+ )
281
+ self.language_model = get_peft_model(self.language_model, lora_config)
282
+ self.language_model.enable_input_require_grads()
283
+ self.language_model.print_trainable_parameters()
284
+
285
+ def pixel_shuffle(self, x, scale_factor=0.5):
286
+ n, w, h, c = x.size()
287
+ # N, W, H, C --> N, W, H * scale, C // scale
288
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
289
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
290
+ x = x.permute(0, 2, 1, 3).contiguous()
291
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
292
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
293
+ int(c / (scale_factor * scale_factor)))
294
+ if self.ps_version == 'v1':
295
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
296
+ 'which results in a transposed image.')
297
+ else:
298
+ x = x.permute(0, 2, 1, 3).contiguous()
299
+ return x
300
+
301
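Aside (illustration, not part of the uploaded file): a standalone copy of the ps_version 'v2' branch above, handy for checking shapes; a 32x32 token grid (448 / 14) with 1024 channels becomes a 16x16 grid with 4096 channels:

import torch

def pixel_shuffle_demo(x, scale_factor=0.5):
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale_factor), int(w * scale_factor), int(c / (scale_factor ** 2)))
    return x.permute(0, 2, 1, 3).contiguous()  # v2: swap height and width back

feat = torch.randn(1, 32, 32, 1024)
print(pixel_shuffle_demo(feat).shape)  # torch.Size([1, 16, 16, 4096])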
+ def extract_feature(self, pixel_values):
302
+ if self.select_layer == -1:
303
+ vit_embeds = self.vision_model(
304
+ pixel_values=pixel_values,
305
+ output_hidden_states=False,
306
+ return_dict=True).last_hidden_state
307
+ else:
308
+ vit_embeds = self.vision_model(
309
+ pixel_values=pixel_values,
310
+ output_hidden_states=True,
311
+ return_dict=True).hidden_states[self.select_layer]
312
+ vit_embeds = vit_embeds[:, 1:, :]
313
+
314
+ h = w = int(vit_embeds.shape[1] ** 0.5)
315
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
316
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
317
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
318
+ vit_embeds = self.mlp1(vit_embeds)
319
+ return vit_embeds
320
+
321
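Aside (illustration, not part of the uploaded file): the number of tokens produced per tile by extract_feature matches the constructor's num_image_token formula; with the defaults set later in preparing_for_generation (image_size=448, patch_size=14, downsample_ratio=0.5) each 448x448 tile yields 256 <IMG_CONTEXT> tokens:

image_size, patch_size, downsample_ratio = 448, 14, 0.5
num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
print(num_image_token)  # 256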
+ @property
322
+ def lm_head(self):
323
+ return self.language_model.get_output_embeddings()
324
+
325
+ def get_input_embeddings(self):
326
+ return self.language_model.get_input_embeddings()
327
+
328
+ def get_output_embeddings(self):
329
+ return self.language_model.get_output_embeddings()
330
+
331
+ def forward(self, data, data_samples=None, mode='loss'):
332
+ pixel_values = data['pixel_values']
333
+
334
+ if type(pixel_values) is list or pixel_values.ndim == 5:
335
+ if type(pixel_values) is list:
336
+ pixel_values = [
337
+ x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
338
+ ]
339
+ # b*n, c, h, w
340
+ concat_images = torch.cat(
341
+ [image.to(self.vision_model.dtype) for image in pixel_values], dim=0)
342
+ else:
343
+ raise NotImplementedError()
344
+
345
+ input_ids = data['input_ids']
346
+ position_ids = data['position_ids']
347
+ attention_mask = data['attention_mask']
348
+ # frames whose pixel values sum to 0 are text-only placeholders
349
+ image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
350
+ image_flags = image_flags.long()
351
+
352
+ labels = data['labels']
353
+ use_cache = False
354
+
355
+ if 'vp_overall_mask' not in data.keys():
356
+ vp_overall_mask = None
357
+ else:
358
+ vp_overall_mask = data['vp_overall_mask']
359
+
360
+ if 'prompt_masks' in data.keys():
361
+ prompt_masks = data['prompt_masks']
362
+ else:
363
+ prompt_masks = None
364
+
365
+ outputs = self._llm_forward(
366
+ input_ids=input_ids,
367
+ position_ids=position_ids,
368
+ attention_mask=attention_mask,
369
+ image_flags=image_flags,
370
+ pixel_values=concat_images,
371
+ labels=labels,
372
+ use_cache=use_cache,
373
+ output_hidden_states=True,
374
+ vp_overall_mask=vp_overall_mask,
375
+ prompt_masks=prompt_masks,
376
+ )
377
+
378
+ return outputs
379
+
380
+ def _llm_forward(
381
+ self,
382
+ pixel_values: torch.FloatTensor,
383
+ input_ids: torch.LongTensor = None,
384
+ attention_mask: Optional[torch.Tensor] = None,
385
+ position_ids: Optional[torch.LongTensor] = None,
386
+ image_flags: Optional[torch.LongTensor] = None,
387
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
388
+ labels: Optional[torch.LongTensor] = None,
389
+ use_cache: Optional[bool] = None,
390
+ output_attentions: Optional[bool] = None,
391
+ output_hidden_states: Optional[bool] = None,
392
+ return_dict: Optional[bool] = None,
393
+ vp_overall_mask=None,
394
+ prompt_masks=None,
395
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
396
+ return_dict = return_dict if return_dict is not None \
397
+ else self.config.use_return_dict
398
+
399
+ image_flags = image_flags.squeeze(-1)
400
+ # Clone so that the in-place token replacement below does not raise an autograd error.
401
+ input_embeds = self.language_model.get_input_embeddings()(
402
+ input_ids).clone()
403
+
404
+ vit_embeds = self.extract_feature(pixel_values)
405
+ vit_embeds = vit_embeds.to(input_embeds.dtype) # FIXME: why is vit_embeds float16?
406
+ fast_vit_embeds = None
407
+
408
+ vit_embeds = vit_embeds[image_flags == 1]
409
+ vit_batch_size = pixel_values.shape[0]
410
+
411
+ B, N, C = input_embeds.shape
412
+ input_embeds = input_embeds.reshape(B * N, C)
413
+
414
+ self._count += 1
415
+
416
+ if vp_overall_mask is not None and prompt_masks is not None:
417
+ vp_embeds = []
418
+ vp_overall_mask = vp_overall_mask.to(vit_embeds.device).bool()
419
+ prompt_masks = [item.to(vit_embeds.device).bool() for item in prompt_masks]
420
+
421
+ vp_overall_mask = vp_overall_mask[image_flags == 1]
422
+ overall_tile_vit_embeds = vit_embeds[vp_overall_mask] # (n_img, hw, c)
423
+
424
+ i_vp_img = 0
425
+ for i_img in range(len(vit_embeds)):
426
+ vp_embeds.append(vit_embeds[i_img].reshape(-1, C))
427
+ if vp_overall_mask[i_img]:
428
+ tile_vit_embeds = overall_tile_vit_embeds[i_vp_img].reshape(-1, C) # (hw, C)
429
+ objects_prompt_masks = prompt_masks[i_vp_img]
430
+ n_obj = len(objects_prompt_masks)
431
+ tile_vit_embeds = tile_vit_embeds.unsqueeze(0).repeat(n_obj, 1, 1)
432
+ objects_prompt_masks = objects_prompt_masks.reshape(n_obj, -1)
433
+ vp_embeds.append(tile_vit_embeds[objects_prompt_masks])
434
+ i_vp_img += 1
435
+ vp_embeds = torch.cat(vp_embeds, dim=0)
436
+ else:
437
+ vp_embeds = None
438
+
439
+ input_ids = input_ids.reshape(B * N)
440
+ selected = (input_ids == self.img_context_token_id)
441
+
442
+ if vp_embeds is None:
443
+ try:
444
+ input_embeds[selected] = vit_embeds.reshape(-1, C)
445
+ except Exception as e:
446
+ vit_embeds = vit_embeds.reshape(-1, C)
447
+ print(f'warning: {e}, input_embeds[selected].shape='
448
+ f'{input_embeds[selected].shape}, '
449
+ f'vit_embeds.shape={vit_embeds.shape}')
450
+ n_token = selected.sum()
451
+ if n_token > len(vit_embeds):
452
+ print(f"Wrong !!! {n_token} image tokens in text but only {len(vit_embeds)} vit embeds !!!")
453
+ expand_ratio = n_token // len(vit_embeds) + 1
454
+ vit_embeds = torch.cat([vit_embeds] * expand_ratio, dim=0)
455
+
456
+ input_embeds[selected] = vit_embeds[:n_token]
457
+ else:
458
+ try:
459
+ input_embeds[selected] = vp_embeds.reshape(-1, C)
460
+ except Exception as e:
461
+ vp_embeds = vp_embeds.reshape(-1, C)
462
+ print(f'warning: {e}, input_embeds[selected].shape='
463
+ f'{input_embeds[selected].shape}, '
464
+ f'vp_embeds.shape={vp_embeds.shape}')
465
+ n_token = selected.sum()
466
+ if n_token > len(vp_embeds):
467
+ print(f"Wrong !!! {n_token} image tokens in text but only {len(vp_embeds)} vit embeds !!!")
468
+ expand_ratio = n_token // len(vp_embeds) + 1
469
+ vp_embeds = torch.cat([vp_embeds] * expand_ratio, dim=0)
470
+
471
+ input_embeds[selected] = vp_embeds[:n_token]
472
+
473
+ input_embeds = input_embeds.reshape(B, N, C)
474
+
475
+ outputs = self.language_model(
476
+ inputs_embeds=input_embeds,
477
+ attention_mask=attention_mask,
478
+ position_ids=position_ids,
479
+ past_key_values=past_key_values,
480
+ use_cache=use_cache,
481
+ output_attentions=output_attentions,
482
+ output_hidden_states=output_hidden_states,
483
+ return_dict=return_dict,
484
+ )
485
+ logits = outputs.logits
486
+
487
+ loss = None
488
+ if labels is not None:
489
+ # Shift so that tokens < n predict n
490
+ shift_logits = logits[..., :-1, :].contiguous()
491
+ shift_labels = labels[..., 1:].contiguous()
492
+ # Flatten the tokens
493
+ loss_fct = CrossEntropyLoss()
494
+ shift_logits = shift_logits.view(
495
+ -1, self.language_model.config.vocab_size)
496
+ shift_labels = shift_labels.view(-1)
497
+ # Enable model parallelism
498
+ shift_labels = shift_labels.to(shift_logits.device)
499
+ loss = loss_fct(shift_logits, shift_labels)
500
+
501
+ if not return_dict:
502
+ output = (logits,) + outputs[1:]
503
+ return (loss,) + output if loss is not None else output
504
+
505
+ return CausalLMOutputWithPast(
506
+ loss=loss,
507
+ logits=logits,
508
+ past_key_values=outputs.past_key_values,
509
+ hidden_states=outputs.hidden_states,
510
+ attentions=outputs.attentions,
511
+ )
512
+
513
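Aside (illustration, not part of the uploaded file): the loss block in _llm_forward is the standard shift-by-one language-modelling objective; a toy check with random logits, where -100 marks positions ignored by CrossEntropyLoss:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 32
logits = torch.randn(1, 5, vocab_size)
labels = torch.tensor([[-100, 7, 3, -100, 9]])
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
print(CrossEntropyLoss()(shift_logits, shift_labels))  # -100 labels are skipped by default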
+ @torch.no_grad()
514
+ def generate(
515
+ self,
516
+ pixel_values: Optional[torch.FloatTensor] = None,
517
+ input_ids: Optional[torch.LongTensor] = None,
518
+ attention_mask: Optional[torch.LongTensor] = None,
519
+ visual_features: Optional[torch.FloatTensor] = None,
520
+ generation_config: Optional[GenerationConfig] = None,
521
+ output_hidden_states: Optional[bool] = None,
522
+ return_dict: Optional[bool] = None,
523
+ prompt_masks=None,
524
+ vp_overall_mask=None,
525
+ query_embeds=None,
526
+ **generate_kwargs,
527
+ ) -> torch.LongTensor:
528
+ device = self.device
529
+ assert self.img_context_token_id is not None
530
+
531
+ if pixel_values is not None:
532
+ if visual_features is not None:
533
+ vit_embeds = visual_features
534
+ else:
535
+ if type(pixel_values) is list or pixel_values.ndim == 5:
536
+ if type(pixel_values) is list:
537
+ pixel_values = [
538
+ x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
539
+ ]
540
+ # b*n, c, h, w
541
+ pixel_values = torch.cat(
542
+ [image.to(self.vision_model.dtype) for image in pixel_values], dim=0)
543
+
544
+ vit_embeds = self.extract_feature(pixel_values.to(device))
545
+ image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
546
+ image_flags = image_flags.long()
547
+ vit_embeds = vit_embeds[image_flags == 1]
548
+
549
+ input_embeds = self.language_model.get_input_embeddings()(input_ids.to(device))
550
+ B, N, C = input_embeds.shape
551
+ input_embeds = input_embeds.reshape(B * N, C)
552
+
553
+ input_ids = input_ids.reshape(B * N)
554
+ selected = (input_ids == self.img_context_token_id)
555
+ assert selected.sum() != 0
556
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
557
+
558
+ # object queries
559
+ query_embeds = query_embeds.to(input_embeds.dtype)
560
+ selected = (input_ids == self.obj_context_token_id)
561
+ input_embeds[selected] = query_embeds.reshape(-1, C)
562
+
563
+ input_embeds = input_embeds.reshape(B, N, C)
564
+ else:
565
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
566
+
567
+ outputs = self.language_model.generate(
568
+ inputs_embeds=input_embeds,
569
+ attention_mask=attention_mask.to(device),
570
+ generation_config=generation_config,
571
+ output_hidden_states=output_hidden_states,
572
+ # return_dict=return_dict,
573
+ use_cache=True,
574
+ **generate_kwargs,
575
+ )
576
+
577
+ return outputs
578
+
579
+ def preparing_for_generation(self, tokenizer, max_new_tokens=2048, torch_dtype=torch.bfloat16):
580
+ # set stop criteria and generation configs for model
581
+ if not hasattr(self, 'tokenizer'):
582
+ self.tokenizer = tokenizer
583
+ self.bot_name = 'BOT'
584
+ stop_words = []
585
+ stop_words += self.template.get('STOP_WORDS', [])
586
+ stop_criteria = get_stop_criteria(
587
+ tokenizer=self.tokenizer, stop_words=stop_words)
588
+ self.stop_criteria = stop_criteria
589
+
590
+ default_generation_kwargs = dict(
591
+ max_new_tokens=max_new_tokens,
592
+ do_sample=False,
593
+ eos_token_id=self.tokenizer.eos_token_id,
594
+ pad_token_id=(
595
+ self.tokenizer.pad_token_id
596
+ if self.tokenizer.pad_token_id is not None
597
+ else self.tokenizer.eos_token_id
598
+ ),
599
+ )
600
+
601
+ self.gen_config = GenerationConfig(**default_generation_kwargs)
602
+ self.init_prediction_config = True
603
+ self.torch_dtype = torch_dtype
604
+ self.to(torch_dtype)
605
+ self.extra_image_processor = DirectResize(target_length=1024)
606
+ # for multi image process
607
+ self.min_dynamic_patch = 1
608
+ self.max_dynamic_patch = 12
609
+ self.downsample_ratio = 0.5
610
+ self.image_size = 448
611
+ self.use_thumbnail = True
612
+ patch_size = 14
613
+ self.patch_size = patch_size
614
+
615
+ self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
616
+ self.IMAGENET_MEAN = (0.485, 0.456, 0.406)
617
+ self.IMAGENET_STD = (0.229, 0.224, 0.225)
618
+ self.IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
619
+ self.IMG_START_TOKEN = '<img>'
620
+ self.IMG_END_TOKEN = '</img>'
621
+
622
+ self.transformer = T.Compose([
623
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
624
+ T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
625
+ T.ToTensor(),
626
+ T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
627
+ ])
628
+
629
+ # override the Phi3 prepare_inputs_for_generation function
630
+ if self.config.llm_config.architectures[0] == 'Phi3ForCausalLM':
631
+ self.language_model.prepare_inputs_for_generation = MethodType(prepare_inputs_for_generation_phi3, self.language_model)
632
+
633
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
634
+ self.img_context_token_id = img_context_token_id
635
+ obj_context_token_id = tokenizer.convert_tokens_to_ids(OBJ_CONTEXT_TOKEN)
636
+ self.obj_context_token_id = obj_context_token_id
637
+
638
+ self.PROPOSAL_TOKENS = [SEG_TOKEN.format(id=str(i).zfill(3)) for i in range(self.num_m2f_proposals)]
639
+ self.the_first_seg_token_idx = self.tokenizer(self.PROPOSAL_TOKENS[0], add_special_tokens=False).input_ids[0]
640
+ self.the_last_seg_token_idx = self.tokenizer(self.PROPOSAL_TOKENS[-1], add_special_tokens=False).input_ids[0]
641
+ self.cls_token_idx = self.tokenizer(CLS_TOKEN, add_special_tokens=False).input_ids[0]
642
+ self.bg_cls_token_idx = self.tokenizer(BG_CLS_TOKEN, add_special_tokens=False).input_ids[0]
643
+
644
+ return
645
+
646
+ def predict_forward(
647
+ self,
648
+ image=None,
649
+ video=None,
650
+ text=None,
651
+ past_text='',
652
+ mask_prompts=None,
653
+ tokenizer=None,
654
+ m2f_processor=None,
655
+ ):
656
+ if not self.init_prediction_config:
657
+ assert tokenizer
658
+ self.preparing_for_generation(tokenizer=tokenizer)
659
+
660
+ if image is None and video is None and '<image>' not in past_text:
661
+ text = text.replace('<image>', "")
662
+ input_text = ''
663
+ input_text += self.template['INSTRUCTION'].format(
664
+ input=text, round=1, bot_name=self.bot_name)
665
+ input_text = past_text + input_text
666
+ ids = self.tokenizer.encode(input_text)
667
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
668
+
669
+ attention_mask = torch.ones_like(ids, dtype=torch.bool)
670
+
671
+ mm_inputs = {
672
+ 'pixel_values': None,
673
+ 'input_ids': ids,
674
+ 'attention_mask': attention_mask,
675
+ 'position_ids': None,
676
+ 'past_key_values': None,
677
+ 'labels': None,
678
+ 'prompt_masks': None,
679
+ 'vp_overall_mask': None,
680
+ 'm2f_inputs': None,
681
+ }
682
+ else:
683
+ input_dict = {}
684
+ if video is not None:
685
+ pixel_values = []
686
+ ori_image_size = video[0].size
687
+ for frame_idx, frame_image in enumerate(video):
688
+ assert ori_image_size == frame_image.size
689
+ img = self.transformer(frame_image)
690
+ pixel_values.append(img)
691
+
692
+ pixel_values = torch.stack(pixel_values, dim=0).to(self.torch_dtype) # (n_f, 3, h, w)
693
+ num_image_tokens = self.patch_token
694
+ num_frames = len(pixel_values)
695
+
696
+ # prepare mask2former inputs
697
+ m2f_pixel_values, m2f_pixel_masks = [], []
698
+ for frame_idx, frame_image in enumerate(video):
699
+ assert ori_image_size == frame_image.size
700
+ w, h = frame_image.size
701
+ if w > h:
702
+ target_size = (self.m2f_input_size, int(h/w*self.m2f_input_size))
703
+ else:
704
+ target_size = (int(w/h*self.m2f_input_size), self.m2f_input_size)
705
+
706
+ resized_frame_image = frame_image.resize(target_size)
707
+ cur_w, cur_h = resized_frame_image.size
708
+ padded_frame_image = np.ones(shape=(self.m2f_input_size, self.m2f_input_size, 3), dtype=np.uint8) * 255
709
+ padded_frame_image[:cur_h, :cur_w, :] = np.array(resized_frame_image)
710
+ m2f_inputs_i = m2f_processor(images=Image.fromarray(padded_frame_image), return_tensors="pt", do_resize=False)
711
+ m2f_pixel_values.append(m2f_inputs_i['pixel_values'])
712
+ m2f_pixel_masks.append(m2f_inputs_i['pixel_mask'])
713
+ m2f_inputs = {
714
+ 'pixel_values': torch.cat(m2f_pixel_values, dim=0),
715
+ 'pixel_mask': torch.cat(m2f_pixel_masks, dim=0)}
716
+ else:
717
+ ori_image_size = image.size
718
+
719
+ images = dynamic_preprocess(image, self.min_dynamic_patch,
720
+ self.max_dynamic_patch,
721
+ self.image_size, self.use_thumbnail)
722
+
723
+ pixel_values = [self.transformer(patch) for patch in images]
724
+ pixel_values = torch.stack(pixel_values).to(self.torch_dtype)
725
+ num_image_tokens = pixel_values.shape[0] * self.patch_token
726
+ num_frames = 1
727
+
728
+ w, h = image.size
729
+ if w > h:
730
+ target_size = (self.m2f_input_size, int(h/w*self.m2f_input_size))
731
+ else:
732
+ target_size = (int(w/h*self.m2f_input_size), self.m2f_input_size)
733
+
734
+ resized_image = image.resize(target_size)
735
+ cur_w, cur_h = resized_image.size
736
+ padded_image = np.ones(shape=(self.m2f_input_size, self.m2f_input_size, 3), dtype=np.uint8) * 255
737
+ padded_image[:cur_h, :cur_w, :] = np.array(resized_image)
738
+ m2f_inputs = m2f_processor(images=Image.fromarray(padded_image), return_tensors="pt", do_resize=False)
739
+
740
+ input_dict['pixel_values'] = pixel_values
741
+
742
+ # TODO: add a frame tag to indicate the frame order
743
+ image_token_str = f'{self.IMG_START_TOKEN}' \
744
+ f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
745
+ f'{self.IMG_END_TOKEN}'
746
+ object_token_str = f"{OBJ_START_TOKEN}"\
747
+ f"{OBJ_CONTEXT_TOKEN * self.num_m2f_queries}"\
748
+ f"{OBJ_END_TOKEN}"
749
+ image_token_str = image_token_str + '\n' + object_token_str + '\n'
750
+ image_token_str = image_token_str * num_frames
751
+ image_token_str = image_token_str.strip()
752
+
753
+ if '<image>' in text or mask_prompts is not None:
754
+ assert past_text is None or len(past_text) == 0
755
+ text = text.replace('<image>', image_token_str)
756
+ input_text = ''
757
+ input_text += self.template['INSTRUCTION'].format(
758
+ input=text, round=1, bot_name=self.bot_name)
759
+ input_text = past_text + input_text
760
+ ids = self.tokenizer.encode(input_text)
761
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
762
+
763
+ attention_mask = torch.ones_like(ids, dtype=torch.bool)
764
+
765
+ # encode multi-scale visual features into 100~300 queries
766
+ m2f_inputs['pixel_values'] = m2f_inputs['pixel_values'].to(self.mask2former.dtype).to(self.mask2former.device)
767
+ m2f_inputs['pixel_mask'] = m2f_inputs['pixel_mask'].to(self.mask2former.dtype).to(self.mask2former.device)
768
+ query_features, pixel_level_module_output = \
769
+ self.mask2former.forward_first_part(**m2f_inputs)
770
+ query_embeds = self.m2f_to_llm(query_features) # BS, m2f_NQ, 2048
771
+
772
+
773
+ mm_inputs = {
774
+ 'pixel_values': input_dict['pixel_values'],
775
+ 'input_ids': ids,
776
+ 'attention_mask': attention_mask,
777
+ 'position_ids': None,
778
+ 'past_key_values': None,
779
+ 'labels': None,
780
+ 'query_embeds': query_embeds,
781
+ # 'prompt_masks': mask_prompts,
782
+ # 'vp_overall_mask': input_dict['vp_overall_mask'],
783
+ }
784
+
785
+ generate_output = self.generate(
786
+ **mm_inputs,
787
+ generation_config=self.gen_config,
788
+ streamer=None,
789
+ bos_token_id=self.tokenizer.bos_token_id,
790
+ stopping_criteria=self.stop_criteria,
791
+ output_hidden_states=True,
792
+ return_dict_in_generate=True
793
+ )
794
+ predict = self.tokenizer.decode(
795
+ generate_output.sequences[0], skip_special_tokens=False).strip()
796
+
797
+ ret_masks = []
798
+ if image is None and video is None and '<image>' not in past_text:
799
+ return {'prediction': predict, 'prediction_masks': ret_masks, 'm2f_outputs': None}
800
+
801
+ # if there is a seg result, find the seg hidden states
802
+ hidden_states = generate_output.hidden_states
803
+ last_hidden_states = [item[-1][0] for item in hidden_states]
804
+ last_hidden_states = torch.cat(last_hidden_states, dim=0)
805
+
806
+ # get cls tokens
807
+ bg_cls_token_id = torch.as_tensor([self.bg_cls_token_idx,], dtype=ids.dtype, device=ids.device)
808
+ bg_cls_embedding = self.language_model.get_input_embeddings()(bg_cls_token_id).clone()
809
+ output_ids = generate_output.sequences[0][:-1]
810
+ cls_token_mask = ids == self.cls_token_idx
811
+
812
+ # get seg tokens
813
+ seg_token_mask = (output_ids >= self.the_first_seg_token_idx) & (output_ids <= self.the_last_seg_token_idx)
814
+
815
+ do_pano_seg = torch.any(cls_token_mask) & torch.any(seg_token_mask)
816
+
817
+ reason_cls_token_mask = output_ids == self.cls_token_idx
818
+
819
+ do_reason_seg = torch.any(reason_cls_token_mask) & torch.any(seg_token_mask)
820
+
821
+ if not do_pano_seg and not do_reason_seg:
822
+ return {'prediction': predict, 'prediction_masks': ret_masks, 'm2f_outputs': None}
823
+
824
+ # get seg tokens
825
+ seg_hidden_states = last_hidden_states[-len(seg_token_mask):][seg_token_mask].unsqueeze(0)
826
+ seg_hidden_states = self.llm_to_m2f(seg_hidden_states)
827
+
828
+ if do_pano_seg:
829
+ cls_hidden_states = last_hidden_states[:len(cls_token_mask)][cls_token_mask]
830
+ text_classifier = self.llm_to_cls(torch.cat([cls_hidden_states, bg_cls_embedding], dim=0))
831
+ seg_hidden_states = seg_hidden_states.transpose(0, 1)
832
+
833
+ # proposals go through mask2former decoder layers
834
+ m2f_outputs = self.mask2former.forward_second_part(
835
+ query_features=seg_hidden_states[:, :, :self.mask2former.config.hidden_dim], # q, b, c
836
+ query_embeddings=seg_hidden_states[:, :, self.mask2former.config.hidden_dim:], # q, b, c
837
+ pixel_level_module_output=pixel_level_module_output,
838
+ text_classifier=[text_classifier, ],
839
+ mask_labels=None,
840
+ class_labels=None,
841
+ **m2f_inputs
842
+ )
843
+
844
+ tags = re.findall(r'<p>(.*?)</p>', predict)
845
+ label_id_to_text = {id: tag for id, tag in enumerate(tags)}
846
+
847
+ class_queries_logits = m2f_outputs.class_queries_logits
848
+ masks_queries_logits = m2f_outputs.masks_queries_logits
849
+
850
+ m2f_masks = {'label_id_to_text': label_id_to_text,
851
+ 'class_queries_logits': class_queries_logits,
852
+ 'masks_queries_logits': masks_queries_logits}
853
+
854
+ return {'prediction': predict, 'prediction_masks': ret_masks, 'm2f_outputs': m2f_masks}
855
+ elif do_reason_seg:
856
+ raise NotImplementedError
857
+ else:
858
+ raise NotImplementedError
859
+
860
+ def post_process_panoptic_segmentation(
861
+ self,
862
+ class_queries_logits,
863
+ masks_queries_logits,
864
+ threshold: float = 0.5,
865
+ mask_threshold: float = 0.5,
866
+ overlap_mask_area_threshold: float = 0.8,
867
+ label_ids_to_fuse: Optional[Set[int]] = None,
868
+ target_sizes: Optional[List[Tuple[int, int]]] = None,
869
+ ) -> List[Dict]:
870
+
871
+ if label_ids_to_fuse is None:
872
+ logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
873
+ label_ids_to_fuse = set()
874
+
875
+ batch_size = len(class_queries_logits)
876
+
877
+ # Loop over items in the batch
878
+ results: List[Dict[str, TensorType]] = []
879
+
880
+ for i in range(batch_size):
881
+ height, width = target_sizes[i]
882
+ long_edge = height if height > width else width
883
+ masks_queries_logits_i = torch.nn.functional.interpolate(
884
+ masks_queries_logits[i:i+1], size=(long_edge, long_edge), mode="bilinear", align_corners=False
885
+ )
886
+
887
+ mask_probs = masks_queries_logits_i[0].sigmoid()
888
+
889
+ num_labels = class_queries_logits[i].shape[-1] - 1
890
+
891
+ pred_scores, pred_labels = nn.functional.softmax(class_queries_logits[i], dim=-1).max(-1)
892
+
893
+ mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
894
+ mask_probs, pred_scores, pred_labels, threshold, num_labels
895
+ )
896
+
897
+ # No mask found
898
+ if mask_probs_item.shape[0] <= 0:
899
+ segmentation = torch.zeros((height, width)) - 1
900
+ results.append({"segmentation": segmentation, "segments_info": []})
901
+ continue
902
+
903
+ # Get segmentation map and segment information of batch item
904
+ target_size = target_sizes[i] if target_sizes is not None else None
905
+ segmentation, segments = compute_segments(
906
+ mask_probs=mask_probs_item,
907
+ pred_scores=pred_scores_item,
908
+ pred_labels=pred_labels_item,
909
+ mask_threshold=mask_threshold,
910
+ overlap_mask_area_threshold=overlap_mask_area_threshold,
911
+ label_ids_to_fuse=label_ids_to_fuse,
912
+ target_size=target_size,
913
+ )
914
+
915
+ results.append({"segmentation": segmentation, "segments_info": segments})
916
+
917
+ return results
918
+
919
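Aside (illustration, not part of the uploaded file): a hedged usage sketch of the post-processing above; `model`, `tokenizer`, `m2f_processor` and `pil_image` are assumed to be prepared by the caller, and the prompt wording is only illustrative:

pred = model.predict_forward(image=pil_image,
                             text='<image>\nPlease segment everything.',
                             tokenizer=tokenizer,
                             m2f_processor=m2f_processor)
m2f = pred['m2f_outputs']
if m2f is not None:
    panoptic = model.post_process_panoptic_segmentation(
        m2f['class_queries_logits'],
        m2f['masks_queries_logits'],
        target_sizes=[pil_image.size[::-1]],  # (height, width)
    )
    seg_map = panoptic[0]['segmentation']    # integer segment-id map
    segments = panoptic[0]['segments_info']  # label_id indexes into m2f['label_id_to_text']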
+ def get_seg_hidden_states(hidden_states, output_ids, seg_id):
920
+ seg_mask = output_ids == seg_id
921
+ n_out = len(seg_mask)
922
+ if n_out == 0:
923
+ return hidden_states[0:0]
924
+ return hidden_states[-n_out:][seg_mask]
925
+
926
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
927
+ image_size):
928
+ best_ratio_diff = float('inf')
929
+ best_ratio = (1, 1)
930
+ area = width * height
931
+ for ratio in target_ratios:
932
+ target_aspect_ratio = ratio[0] / ratio[1]
933
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
934
+ if ratio_diff < best_ratio_diff:
935
+ best_ratio_diff = ratio_diff
936
+ best_ratio = ratio
937
+ elif ratio_diff == best_ratio_diff:
938
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
939
+ best_ratio = ratio
940
+ return best_ratio
941
+
942
+ def dynamic_preprocess(image,
943
+ min_num=1,
944
+ max_num=6,
945
+ image_size=448,
946
+ use_thumbnail=False):
947
+ orig_width, orig_height = image.size
948
+ aspect_ratio = orig_width / orig_height
949
+
950
+ # calculate the existing image aspect ratio
951
+ target_ratios = {(i, j)
952
+ for n in range(min_num, max_num + 1)
953
+ for i in range(1, n + 1) for j in range(1, n + 1)
954
+ if i * j <= max_num and i * j >= min_num}
955
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
956
+
957
+ # find the closest aspect ratio to the target
958
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
959
+ target_ratios, orig_width,
960
+ orig_height, image_size)
961
+
962
+ # calculate the target width and height
963
+ target_width = image_size * target_aspect_ratio[0]
964
+ target_height = image_size * target_aspect_ratio[1]
965
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
966
+
967
+ # resize the image
968
+ resized_img = image.resize((target_width, target_height))
969
+ processed_images = []
970
+ for i in range(blocks):
971
+ box = ((i % (target_width // image_size)) * image_size,
972
+ (i // (target_width // image_size)) * image_size,
973
+ ((i % (target_width // image_size)) + 1) * image_size,
974
+ ((i // (target_width // image_size)) + 1) * image_size)
975
+ # split the image
976
+ split_img = resized_img.crop(box)
977
+ processed_images.append(split_img)
978
+ assert len(processed_images) == blocks
979
+ if use_thumbnail and len(processed_images) != 1:
980
+ thumbnail_img = image.resize((image_size, image_size))
981
+ processed_images.append(thumbnail_img)
982
+ return processed_images
983
+
984
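Aside (illustration, not part of the uploaded file): how dynamic_preprocess tiles an image, using a dummy PIL image; the tile count depends on the aspect ratio picked by find_closest_aspect_ratio:

from PIL import Image

img = Image.new('RGB', (1280, 720))
tiles = dynamic_preprocess(img, min_num=1, max_num=12, image_size=448, use_thumbnail=True)
print(len(tiles), tiles[0].size)  # N 448x448 crops, plus a 448x448 thumbnail when N > 1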
+
985
+ from transformers.cache_utils import Cache, DynamicCache
986
+
987
+ def prepare_inputs_for_generation_phi3(
988
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
989
+ ):
990
+ if past_key_values is not None:
991
+ if isinstance(past_key_values, Cache):
992
+ cache_length = past_key_values.get_seq_length()
993
+ past_length = past_key_values.seen_tokens
994
+ max_cache_length = past_key_values.get_max_length()
995
+ else:
996
+ cache_length = past_length = past_key_values[0][0].shape[2]
997
+ max_cache_length = None
998
+
999
+ # Keep only the unprocessed tokens:
1000
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1001
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1002
+ # input)
1003
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1004
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
1005
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1006
+ # input_ids based on the past_length.
1007
+ elif past_length < input_ids.shape[1]:
1008
+ input_ids = input_ids[:, past_length:]
1009
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1010
+
1011
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1012
+ if (
1013
+ max_cache_length is not None
1014
+ and attention_mask is not None
1015
+ and cache_length + input_ids.shape[1] > max_cache_length
1016
+ ):
1017
+ attention_mask = attention_mask[:, -max_cache_length:]
1018
+
1019
+ position_ids = kwargs.get('position_ids', None)
1020
+ if attention_mask is not None and position_ids is None:
1021
+ # create position_ids on the fly for batch generation
1022
+ position_ids = attention_mask.long().cumsum(-1) - 1
1023
+ position_ids.masked_fill_(attention_mask == 0, 1)
1024
+ if past_key_values:
1025
+ position_ids = position_ids[:, -input_ids.shape[1]:]
1026
+
1027
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1028
+ if inputs_embeds is not None and (past_key_values is None or len(past_key_values)==0):
1029
+ model_inputs = {'inputs_embeds': inputs_embeds}
1030
+ else:
1031
+ model_inputs = {'input_ids': input_ids}
1032
+
1033
+ model_inputs.update(
1034
+ {
1035
+ 'position_ids': position_ids,
1036
+ 'past_key_values': past_key_values,
1037
+ 'use_cache': kwargs.get('use_cache'),
1038
+ 'attention_mask': attention_mask,
1039
+ }
1040
+ )
1041
+ return model_inputs
1042
+
1043
+
1044
+ # Copied from transformers.models.detr.image_processing_detr.compute_segments
1045
+ def compute_segments(
1046
+ mask_probs,
1047
+ pred_scores,
1048
+ pred_labels,
1049
+ mask_threshold: float = 0.5,
1050
+ overlap_mask_area_threshold: float = 0.8,
1051
+ label_ids_to_fuse: Optional[Set[int]] = None,
1052
+ target_size: Tuple[int, int] = None,
1053
+ ):
1054
+ height = mask_probs.shape[1] if target_size is None else target_size[0]
1055
+ width = mask_probs.shape[2] if target_size is None else target_size[1]
1056
+
1057
+ segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
1058
+ segments: List[Dict] = []
1059
+
1060
+ if target_size is not None:
1061
+ mask_probs = mask_probs[..., :height, :width]
1062
+
1063
+ current_segment_id = 0
1064
+
1065
+ # Weigh each mask by its prediction score
1066
+ mask_probs *= pred_scores.view(-1, 1, 1)
1067
+ mask_labels = mask_probs.argmax(0) # [height, width]
1068
+
1069
+ # Keep track of instances of each class
1070
+ stuff_memory_list: Dict[str, int] = {}
1071
+ for k in range(pred_labels.shape[0]):
1072
+ pred_class = pred_labels[k].item()
1073
+ should_fuse = pred_class in label_ids_to_fuse
1074
+
1075
+ # Check if mask exists and is large enough to be a segment
1076
+ mask_exists, mask_k = check_segment_validity(
1077
+ mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
1078
+ )
1079
+
1080
+ if mask_exists:
1081
+ if pred_class in stuff_memory_list:
1082
+ current_segment_id = stuff_memory_list[pred_class]
1083
+ else:
1084
+ current_segment_id += 1
1085
+
1086
+ # Add current object segment to final segmentation map
1087
+ segmentation[mask_k] = current_segment_id
1088
+ segment_score = round(pred_scores[k].item(), 6)
1089
+ segments.append(
1090
+ {
1091
+ "id": current_segment_id,
1092
+ "label_id": pred_class,
1093
+ "was_fused": should_fuse,
1094
+ "score": segment_score,
1095
+ }
1096
+ )
1097
+ if should_fuse:
1098
+ stuff_memory_list[pred_class] = current_segment_id
1099
+
1100
+ return segmentation, segments
sam2.py ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "<img>",
17
+ "</img>",
18
+ "<IMG_CONTEXT>",
19
+ "<quad>",
20
+ "</quad>",
21
+ "<ref>",
22
+ "</ref>",
23
+ "<box>",
24
+ "</box>"
25
+ ],
26
+ "eos_token": {
27
+ "content": "<|im_end|>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ },
33
+ "pad_token": {
34
+ "content": "<|endoftext|>",
35
+ "lstrip": false,
36
+ "normalized": false,
37
+ "rstrip": false,
38
+ "single_word": false
39
+ }
40
+ }
templates.py ADDED
@@ -0,0 +1,170 @@
1
+
2
+ PROMPT_TEMPLATE = dict(
3
+ default=dict(
4
+ SYSTEM='<|System|>:{system}\n',
5
+ INSTRUCTION='<|User|>:{input}\n<|Bot|>:',
6
+ SEP='\n'),
7
+ zephyr=dict(
8
+ SYSTEM='<|system|>\n{system}\n',
9
+ INSTRUCTION='<|user|>\n{input}\n<|assistant|>\n',
10
+ SEP='\n'),
11
+ internlm_chat=dict(
12
+ SYSTEM='<|System|>:{system}\n',
13
+ INSTRUCTION='<|User|>:{input}<eoh>\n<|Bot|>:',
14
+ SUFFIX='<eoa>',
15
+ SUFFIX_AS_EOS=True,
16
+ SEP='\n',
17
+ STOP_WORDS=['<eoa>']),
18
+ internlm2_chat=dict(
19
+ SYSTEM='<|im_start|>system\n{system}<|im_end|>\n',
20
+ INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
21
+ '<|im_start|>assistant\n'),
22
+ SUFFIX='<|im_end|>',
23
+ SUFFIX_AS_EOS=True,
24
+ SEP='\n',
25
+ STOP_WORDS=['<|im_end|>']),
26
+ moss_sft=dict(
27
+ SYSTEM='{system}\n',
28
+ INSTRUCTION='<|Human|>: {input}<eoh>\n',
29
+ SEP='\n',
30
+ STOP_WORDS=['<eoc>', '<eom>']),
31
+ llama2_chat=dict(
32
+ SYSTEM=(
33
+ '[INST] <<SYS>>\n You are a helpful, respectful and honest '
34
+ 'assistant. Always answer as helpfully as possible, while being '
35
+ 'safe. Your answers should not include any harmful, unethical, '
36
+ 'racist, sexist, toxic, dangerous, or illegal content. Please '
37
+ 'ensure that your responses are socially unbiased and positive in '
38
+ 'nature.\n{system}\n<</SYS>>\n [/INST] '),
39
+ INSTRUCTION='[INST] {input} [/INST]',
40
+ SEP='\n'),
41
+ code_llama_chat=dict(
42
+ SYSTEM='{system}\n', INSTRUCTION='[INST] {input} [/INST]'),
43
+ chatglm2=dict(
44
+ SYSTEM='{system}\n',
45
+ INSTRUCTION='[Round {round}]\n\n问:{input}\n\n答:',
46
+ SEP='\n\n'),
47
+ chatglm3=dict(
48
+ SYSTEM='<|system|>\n{system}',
49
+ INSTRUCTION='<|user|>\n{input}<|assistant|>\n',
50
+ SEP='\n'),
51
+ qwen_chat=dict(
52
+ SYSTEM=('<|im_start|>system\n{system}<|im_end|>\n'),
53
+ INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
54
+ '<|im_start|>assistant\n'),
55
+ SUFFIX='<|im_end|>',
56
+ SUFFIX_AS_EOS=True,
57
+ SEP='\n',
58
+ STOP_WORDS=['<|im_end|>', '<|endoftext|>']),
59
+ baichuan_chat=dict(
60
+ SYSTEM='{system}\n',
61
+ INSTRUCTION='<reserved_102>{input}<reserved_103>',
62
+ SEP='\n'),
63
+ baichuan2_chat=dict(
64
+ SYSTEM='{system}\n',
65
+ INSTRUCTION='<reserved_106>{input}<reserved_107>',
66
+ SEP='\n'),
67
+ wizardlm=dict(
68
+ SYSTEM=('A chat between a curious user and an artificial '
69
+ 'intelligence assistant. The assistant gives '
70
+ 'helpful, detailed, and polite answers to the '
71
+ 'user\'s questions. {system}\n '),
72
+ INSTRUCTION=('USER: {input} ASSISTANT:'),
73
+ SEP='\n'),
74
+ wizardcoder=dict(
75
+ SYSTEM=(
76
+ 'Below is an instruction that describes a task. '
77
+ 'Write a response that appropriately completes the request.\n\n'
78
+ '{system}\n '),
79
+ INSTRUCTION=('### Instruction:\n{input}\n\n### Response:'),
80
+ SEP='\n\n'),
81
+ vicuna=dict(
82
+ SYSTEM=('A chat between a curious user and an artificial '
83
+ 'intelligence assistant. The assistant gives '
84
+ 'helpful, detailed, and polite answers to the '
85
+ 'user\'s questions. {system}\n '),
86
+ INSTRUCTION=('USER: {input} ASSISTANT:'),
87
+ SEP='\n'),
88
+ deepseek_coder=dict(
89
+ SYSTEM=('You are an AI programming assistant, utilizing '
90
+ 'the DeepSeek Coder model, developed by DeepSeek'
91
+ 'Company, and you only answer questions related '
92
+ 'to computer science. For politically sensitive '
93
+ 'questions, security and privacy issues, and '
94
+ 'other non-computer science questions, you will '
95
+ 'refuse to answer. {system}\n'),
96
+ INSTRUCTION=('### Instruction:\n{input}\n### Response:\n'),
97
+ SEP='\n'),
98
+ # TODO: deprecation, v0.2.0
99
+ deepseekcoder=dict(
100
+ SYSTEM=('You are an AI programming assistant, utilizing '
101
+ 'the DeepSeek Coder model, developed by DeepSeek'
102
+ 'Company, and you only answer questions related '
103
+ 'to computer science. For politically sensitive '
104
+ 'questions, security and privacy issues, and '
105
+ 'other non-computer science questions, you will '
106
+ 'refuse to answer. {system}\n'),
107
+ INSTRUCTION=('### Instruction:\n{input}\n### Response:\n'),
108
+ SEP='\n'),
109
+ deepseek_moe=dict(
110
+ SYSTEM=('[INST] {system} [/INST]\n'),
111
+ INSTRUCTION=('[INST] {input} [/INST]'),
112
+ SEP='\n'),
113
+ deepseek_v2=dict(
114
+ SYSTEM='{system}\n\n',
115
+ INSTRUCTION='User: {input}\n\nAssistant: ',
116
+ SUFFIX='<|end▁of▁sentence|>',
117
+ SUFFIX_AS_EOS=True,
118
+ STOP_WORDS=['<|end▁of▁sentence|>']),
119
+ mistral=dict(
120
+ SYSTEM=('[INST] {system} [/INST]\n'),
121
+ INSTRUCTION=('[INST] {input} [/INST]'),
122
+ SEP='\n'),
123
+ mixtral=dict(
124
+ SYSTEM=('[INST] {system} [/INST]\n'),
125
+ INSTRUCTION=('[INST] {input} [/INST]'),
126
+ SEP='\n'),
127
+ minicpm=dict(INSTRUCTION=('<用户> {input} <AI>'), SEP='\n'),
128
+ minicpm3=dict(
129
+ SYSTEM=('<|im_start|>system\n{system}<|im_end|>\n'),
130
+ INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
131
+ '<|im_start|>assistant\n'),
132
+ SUFFIX='<|im_end|>',
133
+ SUFFIX_AS_EOS=True,
134
+ SEP='\n',
135
+ STOP_WORDS=['<|im_end|>', '<|endoftext|>']),
136
+ gemma=dict(
137
+ # `system` field is extended by xtuner
138
+ SYSTEM=('<start_of_turn>system\n{system}<end_of_turn>\n'),
139
+ INSTRUCTION=('<start_of_turn>user\n{input}<end_of_turn>\n'
140
+ '<start_of_turn>model\n'),
141
+ SUFFIX='<end_of_turn>',
142
+ SUFFIX_AS_EOS=False,
143
+ SEP='\n',
144
+ STOP_WORDS=['<end_of_turn>']),
145
+ cohere_chat=dict(
146
+ SYSTEM=('<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}'
147
+ '<|END_OF_TURN_TOKEN|>'),
148
+ INSTRUCTION=(
149
+ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{input}<|END_OF_TURN_TOKEN|>'
150
+ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'),
151
+ SUFFIX='<|END_OF_TURN_TOKEN|>',
152
+ SUFFIX_AS_EOS=True,
153
+ STOP_WORDS=['<|END_OF_TURN_TOKEN|>']),
154
+ llama3_chat=dict(
155
+ SYSTEM=('<|start_header_id|>system<|end_header_id|>\n\n'
156
+ '{system}<|eot_id|>'),
157
+ INSTRUCTION=(
158
+ '<|start_header_id|>user<|end_header_id|>\n\n{input}<|eot_id|>'
159
+ '<|start_header_id|>assistant<|end_header_id|>\n\n'),
160
+ SUFFIX='<|eot_id|>',
161
+ SUFFIX_AS_EOS=True,
162
+ STOP_WORDS=['<|eot_id|>']),
163
+ phi3_chat=dict(
164
+ SYSTEM='<|system|>\n{system}<|end|>\n',
165
+ INSTRUCTION='<|user|>\n{input}<|end|>\n<|assistant|>\n',
166
+ SUFFIX='<|end|>',
167
+ SUFFIX_AS_EOS=True,
168
+ SEP='\n',
169
+ STOP_WORDS=['<|end|>']),
170
+ )
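Aside (illustration, not part of the uploaded file): how predict_forward formats one of these templates; qwen_chat is used here since the added special tokens are Qwen-style, and the extra round/bot_name kwargs are simply ignored by str.format:

template = PROMPT_TEMPLATE['qwen_chat']
prompt = template['INSTRUCTION'].format(input='<image>\nWhat is in this picture?',
                                        round=1, bot_name='BOT')
print(prompt)
# <|im_start|>user
# <image>
# What is in this picture?<|im_end|>
# <|im_start|>assistant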
tokenization_internlm2.py ADDED
@@ -0,0 +1,235 @@
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Tokenization classes for InternLM."""
18
+ import os
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, List, Optional, Tuple
21
+
22
+ import sentencepiece as spm
23
+ from transformers.tokenization_utils import PreTrainedTokenizer
24
+ from transformers.utils import logging
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
29
+
30
+ PRETRAINED_VOCAB_FILES_MAP = {}
31
+
32
+
33
+ # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
34
+ class InternLM2Tokenizer(PreTrainedTokenizer):
35
+ """
36
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
37
+
38
+ Args:
39
+ vocab_file (`str`):
40
+ Path to the vocabulary file.
41
+ """
42
+
43
+ vocab_files_names = VOCAB_FILES_NAMES
44
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
45
+ model_input_names = ['input_ids', 'attention_mask']
46
+ _auto_class = 'AutoTokenizer'
47
+
48
+ def __init__(
49
+ self,
50
+ vocab_file,
51
+ unk_token='<unk>',
52
+ bos_token='<s>',
53
+ eos_token='</s>',
54
+ pad_token='</s>',
55
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
56
+ add_bos_token=True,
57
+ add_eos_token=False,
58
+ decode_with_prefix_space=False,
59
+ clean_up_tokenization_spaces=False,
60
+ **kwargs,
61
+ ):
62
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
63
+ self.vocab_file = vocab_file
64
+ self.add_bos_token = add_bos_token
65
+ self.add_eos_token = add_eos_token
66
+ self.decode_with_prefix_space = decode_with_prefix_space
67
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
68
+ self.sp_model.Load(vocab_file)
69
+ self._no_prefix_space_tokens = None
70
+ super().__init__(
71
+ bos_token=bos_token,
72
+ eos_token=eos_token,
73
+ unk_token=unk_token,
74
+ pad_token=pad_token,
75
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
76
+ **kwargs,
77
+ )
78
+
79
+ @property
80
+ def no_prefix_space_tokens(self):
81
+ if self._no_prefix_space_tokens is None:
82
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
83
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')}
84
+ return self._no_prefix_space_tokens
85
+
86
+ @property
87
+ def vocab_size(self):
88
+ """Returns vocab size"""
89
+ return self.sp_model.get_piece_size()
90
+
91
+ @property
92
+ def bos_token_id(self) -> Optional[int]:
93
+ return self.sp_model.bos_id()
94
+
95
+ @property
96
+ def eos_token_id(self) -> Optional[int]:
97
+ return self.sp_model.eos_id()
98
+
99
+ def get_vocab(self):
100
+ """Returns vocab as a dict"""
101
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
102
+ vocab.update(self.added_tokens_encoder)
103
+ return vocab
104
+
105
+ def _tokenize(self, text):
106
+ """Returns a tokenized string."""
107
+ return self.sp_model.encode(text, out_type=str)
108
+
109
+ def _convert_token_to_id(self, token):
110
+ """Converts a token (str) in an id using the vocab."""
111
+ return self.sp_model.piece_to_id(token)
112
+
113
+ def _convert_id_to_token(self, index):
114
+ """Converts an index (integer) in a token (str) using the vocab."""
115
+ token = self.sp_model.IdToPiece(index)
116
+ return token
117
+
118
+ def _maybe_add_prefix_space(self, tokens, decoded):
119
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
120
+ return ' ' + decoded
121
+ else:
122
+ return decoded
123
+
124
+ def convert_tokens_to_string(self, tokens):
125
+ """Converts a sequence of tokens (string) in a single string."""
126
+ current_sub_tokens = []
127
+ out_string = ''
128
+ prev_is_special = False
129
+ for token in tokens:
130
+ # make sure that special tokens are not decoded using sentencepiece model
131
+ if token in self.all_special_tokens:
132
+ if not prev_is_special:
133
+ out_string += ' '
134
+ out_string += self.sp_model.decode(current_sub_tokens) + token
135
+ prev_is_special = True
136
+ current_sub_tokens = []
137
+ else:
138
+ current_sub_tokens.append(token)
139
+ prev_is_special = False
140
+ out_string += self.sp_model.decode(current_sub_tokens)
141
+ out_string = self.clean_up_tokenization(out_string)
142
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
143
+ return out_string[1:]
144
+
145
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
146
+ """
147
+ Save the vocabulary and special tokens file to a directory.
148
+
149
+ Args:
150
+ save_directory (`str`):
151
+ The directory in which to save the vocabulary.
152
+
153
+ Returns:
154
+ `Tuple(str)`: Paths to the files saved.
155
+ """
156
+ if not os.path.isdir(save_directory):
157
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
158
+ return
159
+ out_vocab_file = os.path.join(
160
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
161
+ )
162
+
163
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
164
+ copyfile(self.vocab_file, out_vocab_file)
165
+ elif not os.path.isfile(self.vocab_file):
166
+ with open(out_vocab_file, 'wb') as fi:
167
+ content_spiece_model = self.sp_model.serialized_model_proto()
168
+ fi.write(content_spiece_model)
169
+
170
+ return (out_vocab_file,)
171
+
172
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
173
+ if self.add_bos_token:
174
+ bos_token_ids = [self.bos_token_id]
175
+ else:
176
+ bos_token_ids = []
177
+
178
+ output = bos_token_ids + token_ids_0
179
+
180
+ if token_ids_1 is not None:
181
+ output = output + token_ids_1
182
+
183
+ if self.add_eos_token:
184
+ output = output + [self.eos_token_id]
185
+
186
+ return output
187
+
188
+ def get_special_tokens_mask(
189
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
190
+ ) -> List[int]:
191
+ """
192
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
193
+ special tokens using the tokenizer `prepare_for_model` method.
194
+
195
+ Args:
196
+ token_ids_0 (`List[int]`):
197
+ List of IDs.
198
+ token_ids_1 (`List[int]`, *optional*):
199
+ Optional second list of IDs for sequence pairs.
200
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
201
+ Whether or not the token list is already formatted with special tokens for the model.
202
+
203
+ Returns:
204
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
205
+ """
206
+ if already_has_special_tokens:
207
+ return super().get_special_tokens_mask(
208
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
209
+ )
210
+
211
+ if token_ids_1 is None:
212
+ return [1] + ([0] * len(token_ids_0)) + [1]
213
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
214
+
215
+ def create_token_type_ids_from_sequences(
216
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217
+ ) -> List[int]:
218
+ """
219
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
220
+ use of token type ids, therefore a list of zeros is returned.
221
+
222
+ Args:
223
+ token_ids_0 (`List[int]`):
224
+ List of IDs.
225
+ token_ids_1 (`List[int]`, *optional*):
226
+ Optional second list of IDs for sequence pairs.
227
+
228
+ Returns:
229
+ `List[int]`: List of zeros.
230
+ """
231
+ eos = [self.eos_token_id]
232
+
233
+ if token_ids_1 is None:
234
+ return len(token_ids_0 + eos) * [0]
235
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
tokenization_internlm2_fast.py ADDED
@@ -0,0 +1,211 @@
 
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """Tokenization Fast class for InternLM."""
18
+ import os
19
+ from shutil import copyfile
20
+ from typing import Any, Dict, Optional, Tuple
21
+
22
+ from tokenizers import Tokenizer, decoders, normalizers, processors
23
+ from tokenizers.models import BPE
24
+ from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS,
25
+ SentencePieceExtractor,
26
+ SpmConverter)
27
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
28
+ from transformers.utils import logging
29
+
30
+ from .tokenization_internlm2 import InternLM2Tokenizer
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
35
+
36
+
37
+ # Modified from transformers.convert_slow_tokenizer.LlamaConverter
38
+ class InternLM2Converter(SpmConverter):
39
+ handle_byte_fallback = True
40
+
41
+ def vocab(self, proto):
42
+ vocab = [
43
+ ('<unk>', 0.0),
44
+ ('<s>', 0.0),
45
+ ('</s>', 0.0),
46
+ ]
47
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
48
+ return vocab
49
+
50
+ def unk_id(self, proto):
51
+ unk_id = 0
52
+ return unk_id
53
+
54
+ def decoder(self, replacement, add_prefix_space):
55
+ return decoders.Sequence(
56
+ [
57
+ decoders.Replace('▁', ' '),
58
+ decoders.ByteFallback(),
59
+ decoders.Fuse(),
60
+ decoders.Strip(content=' ', left=1),
61
+ ]
62
+ )
63
+
64
+ def tokenizer(self, proto):
65
+ model_type = proto.trainer_spec.model_type
66
+ vocab_scores = self.vocab(proto)
67
+ # special tokens
68
+ added_tokens = self.original_tokenizer.added_tokens_decoder
69
+ for i in range(len(vocab_scores)):
70
+ piece, score = vocab_scores[i]
71
+ if i in added_tokens:
72
+ vocab_scores[i] = (added_tokens[i].content, score)
73
+ if model_type == 1:
74
+ raise RuntimeError('InternLM2 is supposed to be a BPE model!')
75
+
76
+ elif model_type == 2:
77
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
78
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
79
+ tokenizer = Tokenizer(
80
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
81
+ )
82
+ tokenizer.add_special_tokens(
83
+ [added_token for index, added_token in added_tokens.items()]
84
+ )
85
+ else:
86
+ raise Exception(
87
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
88
+ )
89
+
90
+ return tokenizer
91
+
92
+ def normalizer(self, proto):
93
+ normalizers_list = []
94
+ if proto.normalizer_spec.add_dummy_prefix:
95
+ normalizers_list.append(normalizers.Prepend(prepend='▁'))
96
+ normalizers_list.append(normalizers.Replace(pattern=' ', content='▁'))
97
+ return normalizers.Sequence(normalizers_list)
98
+
99
+ def pre_tokenizer(self, replacement, add_prefix_space):
100
+ return None
101
+
102
+
103
+ SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter
104
+
105
+
106
+ # Modified from transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
107
+ class InternLM2TokenizerFast(PreTrainedTokenizerFast):
108
+ vocab_files_names = VOCAB_FILES_NAMES
109
+ slow_tokenizer_class = InternLM2Tokenizer
110
+ padding_side = 'left'
111
+ model_input_names = ['input_ids', 'attention_mask']
112
+ _auto_class = 'AutoTokenizer'
113
+
114
+ def __init__(
115
+ self,
116
+ vocab_file,
117
+ unk_token='<unk>',
118
+ bos_token='<s>',
119
+ eos_token='</s>',
120
+ pad_token='</s>',
121
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
122
+ add_bos_token=True,
123
+ add_eos_token=False,
124
+ decode_with_prefix_space=False,
125
+ clean_up_tokenization_spaces=False,
126
+ **kwargs,
127
+ ):
128
+ super().__init__(
129
+ vocab_file=vocab_file,
130
+ unk_token=unk_token,
131
+ bos_token=bos_token,
132
+ eos_token=eos_token,
133
+ pad_token=pad_token,
134
+ sp_model_kwargs=sp_model_kwargs,
135
+ add_bos_token=add_bos_token,
136
+ add_eos_token=add_eos_token,
137
+ decode_with_prefix_space=decode_with_prefix_space,
138
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
139
+ **kwargs,
140
+ )
141
+ self._add_bos_token = add_bos_token
142
+ self._add_eos_token = add_eos_token
143
+ self.update_post_processor()
144
+ self.vocab_file = vocab_file
145
+
146
+ @property
147
+ def can_save_slow_tokenizer(self) -> bool:
148
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
149
+
150
+ def update_post_processor(self):
151
+ """
152
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
153
+ """
154
+ bos = self.bos_token
155
+ bos_token_id = self.bos_token_id
156
+ if bos is None and self.add_bos_token:
157
+ raise ValueError('add_bos_token = True but bos_token = None')
158
+
159
+ eos = self.eos_token
160
+ eos_token_id = self.eos_token_id
161
+ if eos is None and self.add_eos_token:
162
+ raise ValueError('add_eos_token = True but eos_token = None')
163
+
164
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
165
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
166
+
167
+ special_tokens = []
168
+ if self.add_bos_token:
169
+ special_tokens.append((bos, bos_token_id))
170
+ if self.add_eos_token:
171
+ special_tokens.append((eos, eos_token_id))
172
+ self._tokenizer.post_processor = processors.TemplateProcessing(
173
+ single=single, pair=pair, special_tokens=special_tokens
174
+ )
175
+
176
+ @property
177
+ def add_eos_token(self):
178
+ return self._add_eos_token
179
+
180
+ @property
181
+ def add_bos_token(self):
182
+ return self._add_bos_token
183
+
184
+ @add_eos_token.setter
185
+ def add_eos_token(self, value):
186
+ self._add_eos_token = value
187
+ self.update_post_processor()
188
+
189
+ @add_bos_token.setter
190
+ def add_bos_token(self, value):
191
+ self._add_bos_token = value
192
+ self.update_post_processor()
193
+
194
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
195
+ if not self.can_save_slow_tokenizer:
196
+ raise ValueError(
197
+ 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
198
+ 'tokenizer.'
199
+ )
200
+
201
+ if not os.path.isdir(save_directory):
202
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
203
+ return
204
+ out_vocab_file = os.path.join(
205
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
206
+ )
207
+
208
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
209
+ copyfile(self.vocab_file, out_vocab_file)
210
+
211
+ return (out_vocab_file,)
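A minimal usage sketch of what the converter registration and fast-tokenizer class above enable (not part of the uploaded file; the checkpoint path is a placeholder/assumption). Note that this repository's own tokenizer_config.json selects Qwen2Tokenizer, so the InternLM2 classes are only exercised by InternLM2-based variants that ship a ./tokenizer.model.

# Minimal sketch; "path/to/internlm2-style-checkpoint" is a placeholder, not taken from this diff.
from transformers import AutoTokenizer

fast_tok = AutoTokenizer.from_pretrained(
    "path/to/internlm2-style-checkpoint",
    trust_remote_code=True,  # imports the custom tokenizer classes registered via _auto_class
    use_fast=True,           # if no prebuilt tokenizer.json exists, conversion goes through
                             # SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] defined above
)
print(type(fast_tok).__name__)  # typically InternLM2TokenizerFast for such checkpoints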
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d257d75be50ec94137a76982b1ba699695a69d25a660733e8d0e2073bf50328b
3
+ size 11443325
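The three lines above are a Git LFS pointer rather than the tokenizer itself: `oid` is the SHA-256 of the real tokenizer.json and `size` is its byte count (about 11 MB). A minimal sketch of fetching the resolved file with huggingface_hub (the repo id below is a placeholder, not taken from this diff):

from huggingface_hub import hf_hub_download

# hf_hub_download resolves the LFS pointer transparently and returns a local path to the full file.
local_path = hf_hub_download(repo_id="namespace/this-repo", filename="tokenizer.json")
print(local_path)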
tokenizer_config.json ADDED
@@ -0,0 +1,1147 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "151643": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "151644": {
15
+ "content": "<|im_start|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "151645": {
23
+ "content": "<|im_end|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "151646": {
31
+ "content": "<|object_ref_start|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "151647": {
39
+ "content": "<|object_ref_end|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": false,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "151648": {
47
+ "content": "<|box_start|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": false,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "151649": {
55
+ "content": "<|box_end|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": false,
59
+ "single_word": false,
60
+ "special": true
61
+ },
62
+ "151650": {
63
+ "content": "<|quad_start|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": false,
67
+ "single_word": false,
68
+ "special": true
69
+ },
70
+ "151651": {
71
+ "content": "<|quad_end|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": false,
75
+ "single_word": false,
76
+ "special": true
77
+ },
78
+ "151652": {
79
+ "content": "<|vision_start|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": false,
83
+ "single_word": false,
84
+ "special": true
85
+ },
86
+ "151653": {
87
+ "content": "<|vision_end|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false,
92
+ "special": true
93
+ },
94
+ "151654": {
95
+ "content": "<|vision_pad|>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": false,
99
+ "single_word": false,
100
+ "special": true
101
+ },
102
+ "151655": {
103
+ "content": "<|image_pad|>",
104
+ "lstrip": false,
105
+ "normalized": false,
106
+ "rstrip": false,
107
+ "single_word": false,
108
+ "special": true
109
+ },
110
+ "151656": {
111
+ "content": "<|video_pad|>",
112
+ "lstrip": false,
113
+ "normalized": false,
114
+ "rstrip": false,
115
+ "single_word": false,
116
+ "special": true
117
+ },
118
+ "151657": {
119
+ "content": "<tool_call>",
120
+ "lstrip": false,
121
+ "normalized": false,
122
+ "rstrip": false,
123
+ "single_word": false,
124
+ "special": false
125
+ },
126
+ "151658": {
127
+ "content": "</tool_call>",
128
+ "lstrip": false,
129
+ "normalized": false,
130
+ "rstrip": false,
131
+ "single_word": false,
132
+ "special": false
133
+ },
134
+ "151659": {
135
+ "content": "<|fim_prefix|>",
136
+ "lstrip": false,
137
+ "normalized": false,
138
+ "rstrip": false,
139
+ "single_word": false,
140
+ "special": false
141
+ },
142
+ "151660": {
143
+ "content": "<|fim_middle|>",
144
+ "lstrip": false,
145
+ "normalized": false,
146
+ "rstrip": false,
147
+ "single_word": false,
148
+ "special": false
149
+ },
150
+ "151661": {
151
+ "content": "<|fim_suffix|>",
152
+ "lstrip": false,
153
+ "normalized": false,
154
+ "rstrip": false,
155
+ "single_word": false,
156
+ "special": false
157
+ },
158
+ "151662": {
159
+ "content": "<|fim_pad|>",
160
+ "lstrip": false,
161
+ "normalized": false,
162
+ "rstrip": false,
163
+ "single_word": false,
164
+ "special": false
165
+ },
166
+ "151663": {
167
+ "content": "<|repo_name|>",
168
+ "lstrip": false,
169
+ "normalized": false,
170
+ "rstrip": false,
171
+ "single_word": false,
172
+ "special": false
173
+ },
174
+ "151664": {
175
+ "content": "<|file_sep|>",
176
+ "lstrip": false,
177
+ "normalized": false,
178
+ "rstrip": false,
179
+ "single_word": false,
180
+ "special": false
181
+ },
182
+ "151665": {
183
+ "content": "<img>",
184
+ "lstrip": false,
185
+ "normalized": false,
186
+ "rstrip": false,
187
+ "single_word": false,
188
+ "special": true
189
+ },
190
+ "151666": {
191
+ "content": "</img>",
192
+ "lstrip": false,
193
+ "normalized": false,
194
+ "rstrip": false,
195
+ "single_word": false,
196
+ "special": true
197
+ },
198
+ "151667": {
199
+ "content": "<IMG_CONTEXT>",
200
+ "lstrip": false,
201
+ "normalized": false,
202
+ "rstrip": false,
203
+ "single_word": false,
204
+ "special": true
205
+ },
206
+ "151668": {
207
+ "content": "<quad>",
208
+ "lstrip": false,
209
+ "normalized": false,
210
+ "rstrip": false,
211
+ "single_word": false,
212
+ "special": true
213
+ },
214
+ "151669": {
215
+ "content": "</quad>",
216
+ "lstrip": false,
217
+ "normalized": false,
218
+ "rstrip": false,
219
+ "single_word": false,
220
+ "special": true
221
+ },
222
+ "151670": {
223
+ "content": "<ref>",
224
+ "lstrip": false,
225
+ "normalized": false,
226
+ "rstrip": false,
227
+ "single_word": false,
228
+ "special": true
229
+ },
230
+ "151671": {
231
+ "content": "</ref>",
232
+ "lstrip": false,
233
+ "normalized": false,
234
+ "rstrip": false,
235
+ "single_word": false,
236
+ "special": true
237
+ },
238
+ "151672": {
239
+ "content": "<box>",
240
+ "lstrip": false,
241
+ "normalized": false,
242
+ "rstrip": false,
243
+ "single_word": false,
244
+ "special": true
245
+ },
246
+ "151673": {
247
+ "content": "</box>",
248
+ "lstrip": false,
249
+ "normalized": false,
250
+ "rstrip": false,
251
+ "single_word": false,
252
+ "special": true
253
+ },
254
+ "151674": {
255
+ "content": "<p>",
256
+ "lstrip": false,
257
+ "normalized": false,
258
+ "rstrip": false,
259
+ "single_word": false,
260
+ "special": true
261
+ },
262
+ "151675": {
263
+ "content": "</p>",
264
+ "lstrip": false,
265
+ "normalized": false,
266
+ "rstrip": false,
267
+ "single_word": false,
268
+ "special": true
269
+ },
270
+ "151676": {
271
+ "content": "[CLS]",
272
+ "lstrip": false,
273
+ "normalized": false,
274
+ "rstrip": false,
275
+ "single_word": false,
276
+ "special": true
277
+ },
278
+ "151677": {
279
+ "content": "[BG_CLS]",
280
+ "lstrip": false,
281
+ "normalized": false,
282
+ "rstrip": false,
283
+ "single_word": false,
284
+ "special": true
285
+ },
286
+ "151678": {
287
+ "content": "<obj>",
288
+ "lstrip": false,
289
+ "normalized": false,
290
+ "rstrip": false,
291
+ "single_word": false,
292
+ "special": true
293
+ },
294
+ "151679": {
295
+ "content": "</obj>",
296
+ "lstrip": false,
297
+ "normalized": false,
298
+ "rstrip": false,
299
+ "single_word": false,
300
+ "special": true
301
+ },
302
+ "151680": {
303
+ "content": "<OBJ_CONTEXT>",
304
+ "lstrip": false,
305
+ "normalized": false,
306
+ "rstrip": false,
307
+ "single_word": false,
308
+ "special": true
309
+ },
310
+ "151681": {
311
+ "content": "[SEG000]",
312
+ "lstrip": false,
313
+ "normalized": false,
314
+ "rstrip": false,
315
+ "single_word": false,
316
+ "special": true
317
+ },
318
+ "151682": {
319
+ "content": "[SEG001]",
320
+ "lstrip": false,
321
+ "normalized": false,
322
+ "rstrip": false,
323
+ "single_word": false,
324
+ "special": true
325
+ },
326
+ "151683": {
327
+ "content": "[SEG002]",
328
+ "lstrip": false,
329
+ "normalized": false,
330
+ "rstrip": false,
331
+ "single_word": false,
332
+ "special": true
333
+ },
334
+ "151684": {
335
+ "content": "[SEG003]",
336
+ "lstrip": false,
337
+ "normalized": false,
338
+ "rstrip": false,
339
+ "single_word": false,
340
+ "special": true
341
+ },
342
+ "151685": {
343
+ "content": "[SEG004]",
344
+ "lstrip": false,
345
+ "normalized": false,
346
+ "rstrip": false,
347
+ "single_word": false,
348
+ "special": true
349
+ },
350
+ "151686": {
351
+ "content": "[SEG005]",
352
+ "lstrip": false,
353
+ "normalized": false,
354
+ "rstrip": false,
355
+ "single_word": false,
356
+ "special": true
357
+ },
358
+ "151687": {
359
+ "content": "[SEG006]",
360
+ "lstrip": false,
361
+ "normalized": false,
362
+ "rstrip": false,
363
+ "single_word": false,
364
+ "special": true
365
+ },
366
+ "151688": {
367
+ "content": "[SEG007]",
368
+ "lstrip": false,
369
+ "normalized": false,
370
+ "rstrip": false,
371
+ "single_word": false,
372
+ "special": true
373
+ },
374
+ "151689": {
375
+ "content": "[SEG008]",
376
+ "lstrip": false,
377
+ "normalized": false,
378
+ "rstrip": false,
379
+ "single_word": false,
380
+ "special": true
381
+ },
382
+ "151690": {
383
+ "content": "[SEG009]",
384
+ "lstrip": false,
385
+ "normalized": false,
386
+ "rstrip": false,
387
+ "single_word": false,
388
+ "special": true
389
+ },
390
+ "151691": {
391
+ "content": "[SEG010]",
392
+ "lstrip": false,
393
+ "normalized": false,
394
+ "rstrip": false,
395
+ "single_word": false,
396
+ "special": true
397
+ },
398
+ "151692": {
399
+ "content": "[SEG011]",
400
+ "lstrip": false,
401
+ "normalized": false,
402
+ "rstrip": false,
403
+ "single_word": false,
404
+ "special": true
405
+ },
406
+ "151693": {
407
+ "content": "[SEG012]",
408
+ "lstrip": false,
409
+ "normalized": false,
410
+ "rstrip": false,
411
+ "single_word": false,
412
+ "special": true
413
+ },
414
+ "151694": {
415
+ "content": "[SEG013]",
416
+ "lstrip": false,
417
+ "normalized": false,
418
+ "rstrip": false,
419
+ "single_word": false,
420
+ "special": true
421
+ },
422
+ "151695": {
423
+ "content": "[SEG014]",
424
+ "lstrip": false,
425
+ "normalized": false,
426
+ "rstrip": false,
427
+ "single_word": false,
428
+ "special": true
429
+ },
430
+ "151696": {
431
+ "content": "[SEG015]",
432
+ "lstrip": false,
433
+ "normalized": false,
434
+ "rstrip": false,
435
+ "single_word": false,
436
+ "special": true
437
+ },
438
+ "151697": {
439
+ "content": "[SEG016]",
440
+ "lstrip": false,
441
+ "normalized": false,
442
+ "rstrip": false,
443
+ "single_word": false,
444
+ "special": true
445
+ },
446
+ "151698": {
447
+ "content": "[SEG017]",
448
+ "lstrip": false,
449
+ "normalized": false,
450
+ "rstrip": false,
451
+ "single_word": false,
452
+ "special": true
453
+ },
454
+ "151699": {
455
+ "content": "[SEG018]",
456
+ "lstrip": false,
457
+ "normalized": false,
458
+ "rstrip": false,
459
+ "single_word": false,
460
+ "special": true
461
+ },
462
+ "151700": {
463
+ "content": "[SEG019]",
464
+ "lstrip": false,
465
+ "normalized": false,
466
+ "rstrip": false,
467
+ "single_word": false,
468
+ "special": true
469
+ },
470
+ "151701": {
471
+ "content": "[SEG020]",
472
+ "lstrip": false,
473
+ "normalized": false,
474
+ "rstrip": false,
475
+ "single_word": false,
476
+ "special": true
477
+ },
478
+ "151702": {
479
+ "content": "[SEG021]",
480
+ "lstrip": false,
481
+ "normalized": false,
482
+ "rstrip": false,
483
+ "single_word": false,
484
+ "special": true
485
+ },
486
+ "151703": {
487
+ "content": "[SEG022]",
488
+ "lstrip": false,
489
+ "normalized": false,
490
+ "rstrip": false,
491
+ "single_word": false,
492
+ "special": true
493
+ },
494
+ "151704": {
495
+ "content": "[SEG023]",
496
+ "lstrip": false,
497
+ "normalized": false,
498
+ "rstrip": false,
499
+ "single_word": false,
500
+ "special": true
501
+ },
502
+ "151705": {
503
+ "content": "[SEG024]",
504
+ "lstrip": false,
505
+ "normalized": false,
506
+ "rstrip": false,
507
+ "single_word": false,
508
+ "special": true
509
+ },
510
+ "151706": {
511
+ "content": "[SEG025]",
512
+ "lstrip": false,
513
+ "normalized": false,
514
+ "rstrip": false,
515
+ "single_word": false,
516
+ "special": true
517
+ },
518
+ "151707": {
519
+ "content": "[SEG026]",
520
+ "lstrip": false,
521
+ "normalized": false,
522
+ "rstrip": false,
523
+ "single_word": false,
524
+ "special": true
525
+ },
526
+ "151708": {
527
+ "content": "[SEG027]",
528
+ "lstrip": false,
529
+ "normalized": false,
530
+ "rstrip": false,
531
+ "single_word": false,
532
+ "special": true
533
+ },
534
+ "151709": {
535
+ "content": "[SEG028]",
536
+ "lstrip": false,
537
+ "normalized": false,
538
+ "rstrip": false,
539
+ "single_word": false,
540
+ "special": true
541
+ },
542
+ "151710": {
543
+ "content": "[SEG029]",
544
+ "lstrip": false,
545
+ "normalized": false,
546
+ "rstrip": false,
547
+ "single_word": false,
548
+ "special": true
549
+ },
550
+ "151711": {
551
+ "content": "[SEG030]",
552
+ "lstrip": false,
553
+ "normalized": false,
554
+ "rstrip": false,
555
+ "single_word": false,
556
+ "special": true
557
+ },
558
+ "151712": {
559
+ "content": "[SEG031]",
560
+ "lstrip": false,
561
+ "normalized": false,
562
+ "rstrip": false,
563
+ "single_word": false,
564
+ "special": true
565
+ },
566
+ "151713": {
567
+ "content": "[SEG032]",
568
+ "lstrip": false,
569
+ "normalized": false,
570
+ "rstrip": false,
571
+ "single_word": false,
572
+ "special": true
573
+ },
574
+ "151714": {
575
+ "content": "[SEG033]",
576
+ "lstrip": false,
577
+ "normalized": false,
578
+ "rstrip": false,
579
+ "single_word": false,
580
+ "special": true
581
+ },
582
+ "151715": {
583
+ "content": "[SEG034]",
584
+ "lstrip": false,
585
+ "normalized": false,
586
+ "rstrip": false,
587
+ "single_word": false,
588
+ "special": true
589
+ },
590
+ "151716": {
591
+ "content": "[SEG035]",
592
+ "lstrip": false,
593
+ "normalized": false,
594
+ "rstrip": false,
595
+ "single_word": false,
596
+ "special": true
597
+ },
598
+ "151717": {
599
+ "content": "[SEG036]",
600
+ "lstrip": false,
601
+ "normalized": false,
602
+ "rstrip": false,
603
+ "single_word": false,
604
+ "special": true
605
+ },
606
+ "151718": {
607
+ "content": "[SEG037]",
608
+ "lstrip": false,
609
+ "normalized": false,
610
+ "rstrip": false,
611
+ "single_word": false,
612
+ "special": true
613
+ },
614
+ "151719": {
615
+ "content": "[SEG038]",
616
+ "lstrip": false,
617
+ "normalized": false,
618
+ "rstrip": false,
619
+ "single_word": false,
620
+ "special": true
621
+ },
622
+ "151720": {
623
+ "content": "[SEG039]",
624
+ "lstrip": false,
625
+ "normalized": false,
626
+ "rstrip": false,
627
+ "single_word": false,
628
+ "special": true
629
+ },
630
+ "151721": {
631
+ "content": "[SEG040]",
632
+ "lstrip": false,
633
+ "normalized": false,
634
+ "rstrip": false,
635
+ "single_word": false,
636
+ "special": true
637
+ },
638
+ "151722": {
639
+ "content": "[SEG041]",
640
+ "lstrip": false,
641
+ "normalized": false,
642
+ "rstrip": false,
643
+ "single_word": false,
644
+ "special": true
645
+ },
646
+ "151723": {
647
+ "content": "[SEG042]",
648
+ "lstrip": false,
649
+ "normalized": false,
650
+ "rstrip": false,
651
+ "single_word": false,
652
+ "special": true
653
+ },
654
+ "151724": {
655
+ "content": "[SEG043]",
656
+ "lstrip": false,
657
+ "normalized": false,
658
+ "rstrip": false,
659
+ "single_word": false,
660
+ "special": true
661
+ },
662
+ "151725": {
663
+ "content": "[SEG044]",
664
+ "lstrip": false,
665
+ "normalized": false,
666
+ "rstrip": false,
667
+ "single_word": false,
668
+ "special": true
669
+ },
670
+ "151726": {
671
+ "content": "[SEG045]",
672
+ "lstrip": false,
673
+ "normalized": false,
674
+ "rstrip": false,
675
+ "single_word": false,
676
+ "special": true
677
+ },
678
+ "151727": {
679
+ "content": "[SEG046]",
680
+ "lstrip": false,
681
+ "normalized": false,
682
+ "rstrip": false,
683
+ "single_word": false,
684
+ "special": true
685
+ },
686
+ "151728": {
687
+ "content": "[SEG047]",
688
+ "lstrip": false,
689
+ "normalized": false,
690
+ "rstrip": false,
691
+ "single_word": false,
692
+ "special": true
693
+ },
694
+ "151729": {
695
+ "content": "[SEG048]",
696
+ "lstrip": false,
697
+ "normalized": false,
698
+ "rstrip": false,
699
+ "single_word": false,
700
+ "special": true
701
+ },
702
+ "151730": {
703
+ "content": "[SEG049]",
704
+ "lstrip": false,
705
+ "normalized": false,
706
+ "rstrip": false,
707
+ "single_word": false,
708
+ "special": true
709
+ },
710
+ "151731": {
711
+ "content": "[SEG050]",
712
+ "lstrip": false,
713
+ "normalized": false,
714
+ "rstrip": false,
715
+ "single_word": false,
716
+ "special": true
717
+ },
718
+ "151732": {
719
+ "content": "[SEG051]",
720
+ "lstrip": false,
721
+ "normalized": false,
722
+ "rstrip": false,
723
+ "single_word": false,
724
+ "special": true
725
+ },
726
+ "151733": {
727
+ "content": "[SEG052]",
728
+ "lstrip": false,
729
+ "normalized": false,
730
+ "rstrip": false,
731
+ "single_word": false,
732
+ "special": true
733
+ },
734
+ "151734": {
735
+ "content": "[SEG053]",
736
+ "lstrip": false,
737
+ "normalized": false,
738
+ "rstrip": false,
739
+ "single_word": false,
740
+ "special": true
741
+ },
742
+ "151735": {
743
+ "content": "[SEG054]",
744
+ "lstrip": false,
745
+ "normalized": false,
746
+ "rstrip": false,
747
+ "single_word": false,
748
+ "special": true
749
+ },
750
+ "151736": {
751
+ "content": "[SEG055]",
752
+ "lstrip": false,
753
+ "normalized": false,
754
+ "rstrip": false,
755
+ "single_word": false,
756
+ "special": true
757
+ },
758
+ "151737": {
759
+ "content": "[SEG056]",
760
+ "lstrip": false,
761
+ "normalized": false,
762
+ "rstrip": false,
763
+ "single_word": false,
764
+ "special": true
765
+ },
766
+ "151738": {
767
+ "content": "[SEG057]",
768
+ "lstrip": false,
769
+ "normalized": false,
770
+ "rstrip": false,
771
+ "single_word": false,
772
+ "special": true
773
+ },
774
+ "151739": {
775
+ "content": "[SEG058]",
776
+ "lstrip": false,
777
+ "normalized": false,
778
+ "rstrip": false,
779
+ "single_word": false,
780
+ "special": true
781
+ },
782
+ "151740": {
783
+ "content": "[SEG059]",
784
+ "lstrip": false,
785
+ "normalized": false,
786
+ "rstrip": false,
787
+ "single_word": false,
788
+ "special": true
789
+ },
790
+ "151741": {
791
+ "content": "[SEG060]",
792
+ "lstrip": false,
793
+ "normalized": false,
794
+ "rstrip": false,
795
+ "single_word": false,
796
+ "special": true
797
+ },
798
+ "151742": {
799
+ "content": "[SEG061]",
800
+ "lstrip": false,
801
+ "normalized": false,
802
+ "rstrip": false,
803
+ "single_word": false,
804
+ "special": true
805
+ },
806
+ "151743": {
807
+ "content": "[SEG062]",
808
+ "lstrip": false,
809
+ "normalized": false,
810
+ "rstrip": false,
811
+ "single_word": false,
812
+ "special": true
813
+ },
814
+ "151744": {
815
+ "content": "[SEG063]",
816
+ "lstrip": false,
817
+ "normalized": false,
818
+ "rstrip": false,
819
+ "single_word": false,
820
+ "special": true
821
+ },
822
+ "151745": {
823
+ "content": "[SEG064]",
824
+ "lstrip": false,
825
+ "normalized": false,
826
+ "rstrip": false,
827
+ "single_word": false,
828
+ "special": true
829
+ },
830
+ "151746": {
831
+ "content": "[SEG065]",
832
+ "lstrip": false,
833
+ "normalized": false,
834
+ "rstrip": false,
835
+ "single_word": false,
836
+ "special": true
837
+ },
838
+ "151747": {
839
+ "content": "[SEG066]",
840
+ "lstrip": false,
841
+ "normalized": false,
842
+ "rstrip": false,
843
+ "single_word": false,
844
+ "special": true
845
+ },
846
+ "151748": {
847
+ "content": "[SEG067]",
848
+ "lstrip": false,
849
+ "normalized": false,
850
+ "rstrip": false,
851
+ "single_word": false,
852
+ "special": true
853
+ },
854
+ "151749": {
855
+ "content": "[SEG068]",
856
+ "lstrip": false,
857
+ "normalized": false,
858
+ "rstrip": false,
859
+ "single_word": false,
860
+ "special": true
861
+ },
862
+ "151750": {
863
+ "content": "[SEG069]",
864
+ "lstrip": false,
865
+ "normalized": false,
866
+ "rstrip": false,
867
+ "single_word": false,
868
+ "special": true
869
+ },
870
+ "151751": {
871
+ "content": "[SEG070]",
872
+ "lstrip": false,
873
+ "normalized": false,
874
+ "rstrip": false,
875
+ "single_word": false,
876
+ "special": true
877
+ },
878
+ "151752": {
879
+ "content": "[SEG071]",
880
+ "lstrip": false,
881
+ "normalized": false,
882
+ "rstrip": false,
883
+ "single_word": false,
884
+ "special": true
885
+ },
886
+ "151753": {
887
+ "content": "[SEG072]",
888
+ "lstrip": false,
889
+ "normalized": false,
890
+ "rstrip": false,
891
+ "single_word": false,
892
+ "special": true
893
+ },
894
+ "151754": {
895
+ "content": "[SEG073]",
896
+ "lstrip": false,
897
+ "normalized": false,
898
+ "rstrip": false,
899
+ "single_word": false,
900
+ "special": true
901
+ },
902
+ "151755": {
903
+ "content": "[SEG074]",
904
+ "lstrip": false,
905
+ "normalized": false,
906
+ "rstrip": false,
907
+ "single_word": false,
908
+ "special": true
909
+ },
910
+ "151756": {
911
+ "content": "[SEG075]",
912
+ "lstrip": false,
913
+ "normalized": false,
914
+ "rstrip": false,
915
+ "single_word": false,
916
+ "special": true
917
+ },
918
+ "151757": {
919
+ "content": "[SEG076]",
920
+ "lstrip": false,
921
+ "normalized": false,
922
+ "rstrip": false,
923
+ "single_word": false,
924
+ "special": true
925
+ },
926
+ "151758": {
927
+ "content": "[SEG077]",
928
+ "lstrip": false,
929
+ "normalized": false,
930
+ "rstrip": false,
931
+ "single_word": false,
932
+ "special": true
933
+ },
934
+ "151759": {
935
+ "content": "[SEG078]",
936
+ "lstrip": false,
937
+ "normalized": false,
938
+ "rstrip": false,
939
+ "single_word": false,
940
+ "special": true
941
+ },
942
+ "151760": {
943
+ "content": "[SEG079]",
944
+ "lstrip": false,
945
+ "normalized": false,
946
+ "rstrip": false,
947
+ "single_word": false,
948
+ "special": true
949
+ },
950
+ "151761": {
951
+ "content": "[SEG080]",
952
+ "lstrip": false,
953
+ "normalized": false,
954
+ "rstrip": false,
955
+ "single_word": false,
956
+ "special": true
957
+ },
958
+ "151762": {
959
+ "content": "[SEG081]",
960
+ "lstrip": false,
961
+ "normalized": false,
962
+ "rstrip": false,
963
+ "single_word": false,
964
+ "special": true
965
+ },
966
+ "151763": {
967
+ "content": "[SEG082]",
968
+ "lstrip": false,
969
+ "normalized": false,
970
+ "rstrip": false,
971
+ "single_word": false,
972
+ "special": true
973
+ },
974
+ "151764": {
975
+ "content": "[SEG083]",
976
+ "lstrip": false,
977
+ "normalized": false,
978
+ "rstrip": false,
979
+ "single_word": false,
980
+ "special": true
981
+ },
982
+ "151765": {
983
+ "content": "[SEG084]",
984
+ "lstrip": false,
985
+ "normalized": false,
986
+ "rstrip": false,
987
+ "single_word": false,
988
+ "special": true
989
+ },
990
+ "151766": {
991
+ "content": "[SEG085]",
992
+ "lstrip": false,
993
+ "normalized": false,
994
+ "rstrip": false,
995
+ "single_word": false,
996
+ "special": true
997
+ },
998
+ "151767": {
999
+ "content": "[SEG086]",
1000
+ "lstrip": false,
1001
+ "normalized": false,
1002
+ "rstrip": false,
1003
+ "single_word": false,
1004
+ "special": true
1005
+ },
1006
+ "151768": {
1007
+ "content": "[SEG087]",
1008
+ "lstrip": false,
1009
+ "normalized": false,
1010
+ "rstrip": false,
1011
+ "single_word": false,
1012
+ "special": true
1013
+ },
1014
+ "151769": {
1015
+ "content": "[SEG088]",
1016
+ "lstrip": false,
1017
+ "normalized": false,
1018
+ "rstrip": false,
1019
+ "single_word": false,
1020
+ "special": true
1021
+ },
1022
+ "151770": {
1023
+ "content": "[SEG089]",
1024
+ "lstrip": false,
1025
+ "normalized": false,
1026
+ "rstrip": false,
1027
+ "single_word": false,
1028
+ "special": true
1029
+ },
1030
+ "151771": {
1031
+ "content": "[SEG090]",
1032
+ "lstrip": false,
1033
+ "normalized": false,
1034
+ "rstrip": false,
1035
+ "single_word": false,
1036
+ "special": true
1037
+ },
1038
+ "151772": {
1039
+ "content": "[SEG091]",
1040
+ "lstrip": false,
1041
+ "normalized": false,
1042
+ "rstrip": false,
1043
+ "single_word": false,
1044
+ "special": true
1045
+ },
1046
+ "151773": {
1047
+ "content": "[SEG092]",
1048
+ "lstrip": false,
1049
+ "normalized": false,
1050
+ "rstrip": false,
1051
+ "single_word": false,
1052
+ "special": true
1053
+ },
1054
+ "151774": {
1055
+ "content": "[SEG093]",
1056
+ "lstrip": false,
1057
+ "normalized": false,
1058
+ "rstrip": false,
1059
+ "single_word": false,
1060
+ "special": true
1061
+ },
1062
+ "151775": {
1063
+ "content": "[SEG094]",
1064
+ "lstrip": false,
1065
+ "normalized": false,
1066
+ "rstrip": false,
1067
+ "single_word": false,
1068
+ "special": true
1069
+ },
1070
+ "151776": {
1071
+ "content": "[SEG095]",
1072
+ "lstrip": false,
1073
+ "normalized": false,
1074
+ "rstrip": false,
1075
+ "single_word": false,
1076
+ "special": true
1077
+ },
1078
+ "151777": {
1079
+ "content": "[SEG096]",
1080
+ "lstrip": false,
1081
+ "normalized": false,
1082
+ "rstrip": false,
1083
+ "single_word": false,
1084
+ "special": true
1085
+ },
1086
+ "151778": {
1087
+ "content": "[SEG097]",
1088
+ "lstrip": false,
1089
+ "normalized": false,
1090
+ "rstrip": false,
1091
+ "single_word": false,
1092
+ "special": true
1093
+ },
1094
+ "151779": {
1095
+ "content": "[SEG098]",
1096
+ "lstrip": false,
1097
+ "normalized": false,
1098
+ "rstrip": false,
1099
+ "single_word": false,
1100
+ "special": true
1101
+ },
1102
+ "151780": {
1103
+ "content": "[SEG099]",
1104
+ "lstrip": false,
1105
+ "normalized": false,
1106
+ "rstrip": false,
1107
+ "single_word": false,
1108
+ "special": true
1109
+ }
1110
+ },
1111
+ "additional_special_tokens": [
1112
+ "<|im_start|>",
1113
+ "<|im_end|>",
1114
+ "<|object_ref_start|>",
1115
+ "<|object_ref_end|>",
1116
+ "<|box_start|>",
1117
+ "<|box_end|>",
1118
+ "<|quad_start|>",
1119
+ "<|quad_end|>",
1120
+ "<|vision_start|>",
1121
+ "<|vision_end|>",
1122
+ "<|vision_pad|>",
1123
+ "<|image_pad|>",
1124
+ "<|video_pad|>",
1125
+ "<img>",
1126
+ "</img>",
1127
+ "<IMG_CONTEXT>",
1128
+ "<quad>",
1129
+ "</quad>",
1130
+ "<ref>",
1131
+ "</ref>",
1132
+ "<box>",
1133
+ "</box>"
1134
+ ],
1135
+ "bos_token": null,
1136
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
1137
+ "clean_up_tokenization_spaces": false,
1138
+ "eos_token": "<|im_end|>",
1139
+ "errors": "replace",
1140
+ "extra_special_tokens": {},
1141
+ "model_max_length": 16384,
1142
+ "pad_token": "<|endoftext|>",
1143
+ "padding_side": "right",
1144
+ "split_special_tokens": false,
1145
+ "tokenizer_class": "Qwen2Tokenizer",
1146
+ "unk_token": null
1147
+ }
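Since the config selects Qwen2Tokenizer with the ChatML-style chat_template above, the standard transformers chat-template API applies directly. A minimal sketch, assuming the tokenizer is loaded from this repository (placeholder path):

# Minimal sketch; "path/to/this-repo" is a placeholder for this repository.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this-repo", trust_remote_code=True)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Please segment the person in the image."},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # ends with '<|im_start|>assistant\n' because add_generation_prompt=True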
vocab.json ADDED
The diff for this file is too large to render. See raw diff
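vocab.json carries the base BPE vocabulary; the task-specific markers (<img>, <IMG_CONTEXT>, <p>/</p>, and the [SEG000]–[SEG099] placeholders) live in the added_tokens_decoder above. A minimal sketch of checking that the segmentation placeholders resolve to the ids listed there (placeholder path again):

# Minimal sketch; "path/to/this-repo" is a placeholder for this repository.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("path/to/this-repo", trust_remote_code=True)
seg_tokens = [f"[SEG{i:03d}]" for i in range(100)]      # [SEG000] ... [SEG099]
seg_ids = tok.convert_tokens_to_ids(seg_tokens)
assert seg_ids[0] == 151681 and seg_ids[-1] == 151780   # matches added_tokens_decoder above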