SFLY5 committed
Commit c2771cf · verified · 1 Parent(s): 9914892

Add files using upload-large-folder tool

Files changed (3)
  1. config.json +1 -1
  2. processing_ernie_45t_vl.py +1352 -13
  3. tokenizer_config.json +1 -1
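
The three diffs below consolidate the model's custom preprocessing code: the image processor and tokenizer classes are merged into processing_ernie_45t_vl.py, and the auto_map entries in config.json and tokenizer_config.json are repointed at that module. A minimal sketch of loading the relocated classes afterwards (the repo path is a placeholder, and trust_remote_code is required because the classes ship with the repo rather than with transformers):

```python
from transformers import AutoImageProcessor, AutoProcessor

# "<repo-or-local-path>" is a placeholder for the model repo id or a local checkout.
image_processor = AutoImageProcessor.from_pretrained(
    "<repo-or-local-path>", trust_remote_code=True
)  # resolves to processing_ernie_45t_vl.Ernie_45T_VLImageProcessor
processor = AutoProcessor.from_pretrained(
    "<repo-or-local-path>", trust_remote_code=True
)  # resolves to processing_ernie_45t_vl.Ernie_45T_VLProcessor
```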
config.json CHANGED
@@ -7,7 +7,7 @@
7
  "AutoModel": "modeling_ernie_45t_vl.Ernie4_5_VLMoeForConditionalGeneration",
8
  "AutoModelForCausalLM": "modeling_ernie_45t_vl.Ernie4_5_VLMoeForConditionalGeneration",
9
  "AutoProcessor": "processing_ernie_45t_vl.Ernie_45T_VLProcessor",
10
- "AutoImageProcessor": "image_processing_ernie_45t_vl.Ernie_45T_VLImageProcessor"
+ "AutoImageProcessor": "processing_ernie_45t_vl.Ernie_45T_VLImageProcessor"
11
  },
12
  "torch_dtype": "bfloat16",
13
  "hidden_act": "silu",
processing_ernie_45t_vl.py CHANGED
@@ -12,30 +12,1369 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
- """Processor class for Ernie_45T_VL."""
16
 
17
  import copy
18
  import io
19
 
20
  import numpy as np
21
  import torch
22
- from PIL import Image
 
23
  from collections import defaultdict
24
  from typing import Any, Dict, List, Union
25
 
26
- from .image_processing_ernie_45t_vl import Ernie_45T_VLImageProcessor
27
- from .tokenization_ernie_45t_vl import Ernie4_5_VLTokenizer
28
- from .video_utils_ernie_45t_vl import (
29
- read_frames_decord,
30
- read_video_decord,
31
- RAW_IMAGE_DIR,
32
- get_downloadable,
33
- render_frame_timestamp,
34
- )
35
 
36
- from transformers.image_utils import ChannelDimension
37
  from transformers.processing_utils import ProcessorMixin
38
  from transformers.feature_extraction_utils import BatchFeature
39
 
40
 
41
  IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}
@@ -472,4 +1811,4 @@ class Ernie_45T_VLProcessor(ProcessorMixin):
472
  return list(tokenizer_input_names) + list(image_processor_input_names)
473
 
474
 
475
- __all__ = ["Ernie_45T_VLProcessor"]
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ """Tokenization classes and Image processor class, Processor class for Ernie_45T_VL."""
16
 
17
  import copy
18
  import io
19
+ import os
20
+ import re
21
+ import math
22
+ import random
23
+ import requests
24
+ import base64
25
+ import datetime
26
+ import hashlib
27
+ import threading
28
+ import uuid
29
+ import decord
30
+ from shutil import copyfile
31
+ from typing import Dict, List, Optional, Tuple, Union
32
 
33
  import numpy as np
34
  import torch
35
+ from PIL import Image, ImageDraw, ImageFont
36
+ from PIL.ExifTags import TAGS
37
  from collections import defaultdict
38
  from typing import Any, Dict, List, Union
39
+ from pathlib import Path
40
+ from tempfile import NamedTemporaryFile as ntf
41
 
42
+ try:
43
+ # moviepy 1.0
44
+ import moviepy.editor as mp
45
+ except ImportError:
46
+ # moviepy 2.0
47
+ import moviepy as mp
48
 
49
+ import sentencepiece as spm
50
+ from transformers.tokenization_utils import PreTrainedTokenizer
51
+ from transformers.tokenization_utils_base import (
52
+ PaddingStrategy,
53
+ TextInput,
54
+ )
55
+ from transformers.utils import logging
56
+ from transformers.utils import TensorType, logging
57
+ from transformers.video_utils import VideoInput
58
  from transformers.processing_utils import ProcessorMixin
59
  from transformers.feature_extraction_utils import BatchFeature
60
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
61
+ from transformers.image_transforms import (
62
+ convert_to_rgb,
63
+ normalize,
64
+ rescale,
65
+ resize,
66
+ to_channel_dimension_format,
67
+ )
68
+ from transformers.image_utils import (
69
+ OPENAI_CLIP_MEAN,
70
+ OPENAI_CLIP_STD,
71
+ ChannelDimension,
72
+ ImageInput,
73
+ PILImageResampling,
74
+ get_image_size,
75
+ infer_channel_dimension_format,
76
+ is_valid_image,
77
+ make_list_of_images,
78
+ to_numpy_array,
79
+ valid_images,
80
+ )
81
+
82
+ logger = logging.get_logger(__name__)
83
+
84
+
85
+ def round_by_factor(number: int, factor: int) -> int:
86
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
87
+ return round(number / factor) * factor
88
+
89
+
90
+ def ceil_by_factor(number: int, factor: int) -> int:
91
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
92
+ return math.ceil(number / factor) * factor
93
+
94
+
95
+ def floor_by_factor(number: int, factor: int) -> int:
96
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
97
+ return math.floor(number / factor) * factor
98
+
99
+
100
+ def smart_resize(
101
+ height: int,
102
+ width: int,
103
+ factor: int = 28,
104
+ min_pixels: int = 4 * 28 * 28,
105
+ max_pixels: int = 16384 * 28 * 28,
106
+ ):
107
+ """
108
+ Rescales the image so that the following conditions are met:
109
+
110
+ 1. Both dimensions (height and width) are divisible by 'factor'.
111
+
112
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
113
+
114
+ 3. The aspect ratio of the image is maintained as closely as possible.
115
+ """
116
+ MAX_RATIO = 200
117
+ if max(height, width) / min(height, width) > MAX_RATIO:
118
+ if height > width:
119
+ new_width = max(factor, round_by_factor(width, factor))
120
+ new_height = floor_by_factor(new_width * MAX_RATIO, factor)
121
+ else:
122
+ new_height = max(factor, round_by_factor(height, factor))
123
+ new_width = floor_by_factor(new_height * MAX_RATIO, factor)
124
+
125
+ logger.info(
126
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)},\
127
+ resize to {max(new_height, new_width) / min(new_height, new_width)}"
128
+ )
129
+
130
+ height = new_height
131
+ width = new_width
132
+
133
+ h_bar = max(factor, round_by_factor(height, factor))
134
+ w_bar = max(factor, round_by_factor(width, factor))
135
+ if h_bar * w_bar > max_pixels:
136
+ beta = math.sqrt((height * width) / max_pixels)
137
+ h_bar = floor_by_factor(height / beta, factor)
138
+ w_bar = floor_by_factor(width / beta, factor)
139
+ elif h_bar * w_bar < min_pixels:
140
+ beta = math.sqrt(min_pixels / (height * width))
141
+ h_bar = ceil_by_factor(height * beta, factor)
142
+ w_bar = ceil_by_factor(width * beta, factor)
143
+
144
+ if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
145
+ raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")
146
+
147
+ return h_bar, w_bar
148
+
149
+
150
+ def is_scaled_image(image: np.ndarray) -> bool:
151
+ """
152
+ Checks to see whether the pixel values have already been rescaled to [0, 1].
153
+ """
154
+ if image.dtype == np.uint8:
155
+ return False
156
+
157
+ # It's possible the image has pixel values in [0, 255] but is of floating type
158
+ return np.min(image) >= 0 and np.max(image) <= 1
159
+
160
+
161
+ def make_batched_images(images) -> List[List[ImageInput]]:
162
+ """
163
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
164
+
165
+ Args:
166
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
167
+ The input image.
168
+
169
+ Returns:
170
+ list: A list of images.
171
+ """
172
+ if (
173
+ isinstance(images, (list, tuple))
174
+ and isinstance(images[0], (list, tuple))
175
+ and is_valid_image(images[0][0])
176
+ ):
177
+ return [img for img_list in images for img in img_list]
178
+
179
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
180
+ return images
181
+
182
+ elif is_valid_image(images):
183
+ return [images]
184
+
185
+ raise ValueError(f"Could not make batched images from {images}")
186
+
187
+
188
+ # Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
189
+ def make_batched_videos(videos) -> List[VideoInput]:
190
+ """dummy"""
191
+ if (
192
+ isinstance(videos, (list, tuple))
193
+ and isinstance(videos[0], (list, tuple))
194
+ and is_valid_image(videos[0][0])
195
+ ):
196
+ return videos
197
+
198
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
199
+ if isinstance(videos[0], Image.Image):
200
+ return [videos]
201
+ elif len(videos[0].shape) == 4:
202
+ return [list(video) for video in videos]
203
+
204
+ elif is_valid_image(videos) and len(videos.shape) == 4:
205
+ return [list(videos)]
206
+
207
+ raise ValueError(f"Could not make batched video from {videos}")
208
+
209
+
210
+ class Ernie_45T_VLImageProcessor(BaseImageProcessor):
211
+ r"""
212
+ Constructs an adaptive image processor that dynamically resizes images based on the original images.
213
+
214
+ Args:
215
+ do_resize (`bool`, *optional*, defaults to `True`):
216
+ Whether to resize the image's (height, width) dimensions.
217
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
218
+ Resampling filter to use when resizing the image.
219
+ do_rescale (`bool`, *optional*, defaults to `True`):
220
+ Whether to rescale the image by the specified scale `rescale_factor`.
221
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
222
+ Scale factor to use if rescaling the image.
223
+ do_normalize (`bool`, *optional*, defaults to `True`):
224
+ Whether to normalize the image.
225
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
226
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
227
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
228
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel
229
+ in the image.
230
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
231
+ Whether to convert the image to RGB.
232
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
233
+ The min pixels of the image to resize the image.
234
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
235
+ The max pixels of the image to resize the image.
236
+ patch_size (`int`, *optional*, defaults to 14):
237
+ The spatial patch size of the vision encoder.
238
+ temporal_conv_size (`int`, *optional*, defaults to 2):
239
+ The temporal conv size in resampler.
240
+ merge_size (`int`, *optional*, defaults to 2):
241
+ The merge size of the vision encoder to llm encoder.
242
+ """
243
+
244
+ model_input_names = [
245
+ "pixel_values",
246
+ "image_grid_thw",
247
+ "pixel_values_videos",
248
+ "video_grid_thw",
249
+ ]
250
+
251
+ def __init__(
252
+ self,
253
+ do_resize: bool = True,
254
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
255
+ do_rescale: bool = True,
256
+ rescale_factor: Union[float, List[float]] = 1 / 255,
257
+ do_normalize: bool = True,
258
+ image_mean: Optional[Union[float, List[float]]] = None,
259
+ image_std: Optional[Union[float, List[float]]] = None,
260
+ do_convert_rgb: bool = True,
261
+ min_pixels: int = 56 * 56,
262
+ max_pixels: int = 28 * 28 * 1280,
263
+ patch_size: int = 14,
264
+ temporal_conv_size: int = 2,
265
+ merge_size: int = 2,
266
+ **kwargs,
267
+ ) -> None:
268
+ """init"""
269
+ super().__init__(**kwargs)
270
+ self.do_resize = do_resize
271
+ self.resample = resample
272
+ self.do_rescale = do_rescale
273
+ self.rescale_factor = rescale_factor
274
+ self.do_normalize = do_normalize
275
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
276
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
277
+ self.min_pixels = min_pixels
278
+ self.max_pixels = max_pixels
279
+ self.patch_size = patch_size
280
+ self.temporal_conv_size = temporal_conv_size
281
+ self.merge_size = merge_size
282
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
283
+ self.do_convert_rgb = do_convert_rgb
284
+
285
+ def set_pixels(self, min_pixels=None, max_pixels=None, msg=""):
286
+ """set_pixels"""
287
+ if min_pixels is not None:
288
+ assert (
289
+ isinstance(min_pixels, int) and min_pixels >= 0
290
+ ), "min_pixels must be positive int"
291
+ logger.info(
292
+ f"{msg} Ernie_45T_VLImageProcessor set min_pixels = {min_pixels}"
293
+ )
294
+ self.min_pixels = min_pixels
295
+ self.size["min_pixels"] = int(min_pixels)
296
+ if max_pixels is not None:
297
+ assert (
298
+ isinstance(max_pixels, int) and max_pixels > 0
299
+ ), "max_pixels must be positive int"
300
+ logger.info(
301
+ f"{msg} Ernie_45T_VLImageProcessor set max_pixels = {max_pixels}"
302
+ )
303
+ self.max_pixels = max_pixels
304
+ self.size["max_pixels"] = int(max_pixels)
305
+
306
+ def get_smarted_resize(self, height, width, min_pixels=None, max_pixels=None):
307
+ """dummy"""
308
+ actual_min_pixels = min_pixels if min_pixels is not None else self.min_pixels
309
+ actual_max_pixels = max_pixels if max_pixels is not None else self.max_pixels
310
+ resized_height, resized_width = smart_resize(
311
+ height,
312
+ width,
313
+ factor=self.patch_size * self.merge_size,
314
+ min_pixels=actual_min_pixels,
315
+ max_pixels=actual_max_pixels,
316
+ )
317
+ return (resized_height, resized_width), (
318
+ resized_height // self.patch_size,
319
+ resized_width // self.patch_size,
320
+ )
321
+
322
+ def _preprocess(
323
+ self,
324
+ images: Union[ImageInput, VideoInput],
325
+ do_resize: bool = True,
326
+ resample: PILImageResampling = None,
327
+ do_rescale: bool = True,
328
+ rescale_factor: float = 1 / 255,
329
+ do_normalize: bool = True,
330
+ image_mean: Optional[Union[float, List[float]]] = None,
331
+ image_std: Optional[Union[float, List[float]]] = None,
332
+ do_convert_rgb: bool = False,
333
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
334
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
335
+ predetermined_grid_thw=None,
336
+ ):
337
+ """
338
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
339
+
340
+ Args:
341
+ images (`ImageInput` or `VideoInput`):
342
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255.
343
+ If pixel values range from 0 to 1, set `do_rescale=False`.
344
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
345
+ Whether to resize the image.
346
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
347
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
348
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
349
+ Whether to rescale the image.
350
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
351
+ Scale factor to use if rescaling the image.
352
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
353
+ Whether to normalize the image.
354
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
355
+ Mean to use if normalizing the image.
356
+ Can be a float or a list of floats corresponding to the number of channels in the image.
357
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
358
+ Standard deviation to use if normalizing the image.
359
+ Can be a float or a list of floats corresponding to the number of channels in the image.
360
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
361
+ Whether to convert the image to RGB.
362
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
363
+ The channel dimension format for the output image. Can be one of:
364
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
365
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
366
+ - Unset: Use the channel dimension format of the input image.
367
+ input_data_format (`ChannelDimension` or `str`, *optional*):
368
+ The channel dimension format for the input image. Can be one of:
369
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
370
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
371
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
372
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
373
+ """
374
+ images = make_list_of_images(images)
375
+
376
+ if do_convert_rgb:
377
+ images = [convert_to_rgb(image) for image in images]
378
+
379
+ # All transformations expect numpy arrays.
380
+ images = [to_numpy_array(image) for image in images]
381
+
382
+ if is_scaled_image(images[0]) and do_rescale:
383
+ logger.warning_once(
384
+ "It looks like you are trying to rescale already rescaled images. If the input"
385
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
386
+ )
387
+ if input_data_format is None:
388
+ # We assume that all images have the same channel dimension format.
389
+ input_data_format = infer_channel_dimension_format(images[0])
390
+
391
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
392
+ resized_height, resized_width = height, width
393
+ processed_images = []
394
+
395
+ if predetermined_grid_thw is not None:
396
+ assert len(predetermined_grid_thw) == len(
397
+ images
398
+ ), f"len(predetermined_grid_thw) {len(predetermined_grid_thw)} == len(images) {len(images)}"
399
+
400
+ for img_idx, image in enumerate(images):
401
+ if do_resize:
402
+ if predetermined_grid_thw is not None:
403
+ (resized_height, resized_width) = predetermined_grid_thw[img_idx]
404
+ resized_height *= self.patch_size
405
+ resized_width *= self.patch_size
406
+ else:
407
+ resized_height, resized_width = smart_resize(
408
+ height,
409
+ width,
410
+ factor=self.patch_size * self.merge_size,
411
+ min_pixels=self.min_pixels,
412
+ max_pixels=self.max_pixels,
413
+ )
414
+
415
+ image = resize(
416
+ image,
417
+ size=(resized_height, resized_width),
418
+ resample=resample,
419
+ data_format=input_data_format,
420
+ )
421
+ if do_rescale:
422
+ image = rescale(
423
+ image, scale=rescale_factor, data_format=input_data_format
424
+ )
425
+
426
+ if do_normalize:
427
+ image = normalize(
428
+ image=image,
429
+ mean=image_mean,
430
+ std=image_std,
431
+ data_format=input_data_format,
432
+ )
433
+
434
+ image = to_channel_dimension_format(
435
+ image, data_format, input_channel_dim=input_data_format
436
+ ) # [C, H, W]
437
+
438
+ processed_images.append(image)
439
+ patches = np.array(processed_images)
440
+ if data_format == ChannelDimension.LAST:
441
+ patches = patches.transpose([0, 3, 1, 2])
442
+
443
+ channel = patches.shape[1] # [time, C, H, W]
444
+ grid_t = patches.shape[0]
445
+ grid_h, grid_w = (
446
+ resized_height // self.patch_size,
447
+ resized_width // self.patch_size,
448
+ )
449
+ patches = patches.reshape(
450
+ [
451
+ grid_t,
452
+ channel,
453
+ grid_h // self.merge_size,
454
+ self.merge_size,
455
+ self.patch_size,
456
+ grid_w // self.merge_size,
457
+ self.merge_size,
458
+ self.patch_size,
459
+ ]
460
+ )
461
+ # [grid_t, grid_h/merge_size, grid_w/merge_size, merge_size, merge_size, C, psz, psz]
462
+ patches = patches.transpose([0, 2, 5, 3, 6, 1, 4, 7])
463
+
464
+ flatten_patches = patches.reshape(
465
+ [grid_t * grid_h * grid_w, channel * self.patch_size * self.patch_size]
466
+ ) # [grid_t * grid_h * grid_w, C * psz * psz]
467
+
468
+ return flatten_patches, (grid_t, grid_h, grid_w)
469
+
470
+ def preprocess(
471
+ self,
472
+ images: ImageInput,
473
+ videos: VideoInput = None,
474
+ do_resize: bool = True,
475
+ size: Optional[Union[int, List[int]]] = None,
476
+ resample: PILImageResampling = None,
477
+ do_rescale: bool = True,
478
+ rescale_factor: float = 1 / 255,
479
+ do_normalize: bool = True,
480
+ image_mean: Optional[Union[float, List[float]]] = None,
481
+ image_std: Optional[Union[float, List[float]]] = None,
482
+ do_convert_rgb: bool = False,
483
+ return_tensors: Optional[Union[str, TensorType]] = None,
484
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
485
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
486
+ predetermined_grid_thw=None,
487
+ ):
488
+ """
489
+ Args:
490
+ images (`ImageInput`):
491
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
492
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
493
+ videos (`VideoInput`):
494
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
495
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
496
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
497
+ Whether to resize the image.
498
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
499
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
500
+ the longest edge resized to keep the input aspect ratio.
501
+ resample (`int`, *optional*, defaults to `self.resample`):
502
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
503
+ has an effect if `do_resize` is set to `True`.
504
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
505
+ Whether to rescale the image.
506
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
507
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
508
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
509
+ Whether to normalize the image.
510
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
511
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
512
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
513
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
514
+ `True`.
515
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
516
+ Whether to convert the image to RGB.
517
+ return_tensors (`str` or `TensorType`, *optional*):
518
+ The type of tensors to return. Can be one of:
519
+ - Unset: Return a list of `np.ndarray`.
520
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
521
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
522
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
523
+ The channel dimension format for the output image. Can be one of:
524
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
525
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
526
+ - Unset: Use the channel dimension format of the input image.
527
+ input_data_format (`ChannelDimension` or `str`, *optional*):
528
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
529
+ from the input image. Can be one of:
530
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
531
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
532
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
533
+
534
+ """
535
+ do_resize = do_resize if do_resize is not None else self.do_resize
536
+ size = size if size is not None else self.size
537
+ resample = resample if resample is not None else self.resample
538
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
539
+ rescale_factor = (
540
+ rescale_factor if rescale_factor is not None else self.rescale_factor
541
+ )
542
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
543
+ image_mean = image_mean if image_mean is not None else self.image_mean
544
+ image_std = image_std if image_std is not None else self.image_std
545
+ do_convert_rgb = (
546
+ do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
547
+ )
548
+
549
+ if images is not None:
550
+ images = make_batched_images(images)
551
+
552
+ if images is not None and not valid_images(images):
553
+ raise ValueError(
554
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
555
+ "torch.Tensor."
556
+ )
557
+
558
+ data = {}
559
+ if images is not None:
560
+ pixel_values, vision_grid_thws = [], []
561
+ for img_idx, image in enumerate(images):
562
+ if predetermined_grid_thw is not None:
563
+ predetermined_grid_thw_one = [predetermined_grid_thw[img_idx]]
564
+ else:
565
+ predetermined_grid_thw_one = None
566
+ patches, image_grid_thw = self._preprocess(
567
+ image,
568
+ do_resize=do_resize,
569
+ resample=resample,
570
+ do_rescale=do_rescale,
571
+ rescale_factor=rescale_factor,
572
+ do_normalize=do_normalize,
573
+ image_mean=image_mean,
574
+ image_std=image_std,
575
+ data_format=data_format,
576
+ do_convert_rgb=do_convert_rgb,
577
+ input_data_format=input_data_format,
578
+ predetermined_grid_thw=predetermined_grid_thw_one,
579
+ )
580
+ pixel_values.extend(patches)
581
+ vision_grid_thws.append(image_grid_thw)
582
+ pixel_values = np.array(pixel_values)
583
+ vision_grid_thws = np.array(vision_grid_thws)
584
+ data.update(
585
+ {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
586
+ )
587
+
588
+ if videos is not None:
589
+ videos = make_batched_videos(videos)
590
+ pixel_values, vision_grid_thws = [], []
591
+ for images in videos:
592
+ patches, video_grid_thw = self._preprocess(
593
+ images,
594
+ do_resize=do_resize,
595
+ resample=resample,
596
+ do_rescale=do_rescale,
597
+ rescale_factor=rescale_factor,
598
+ do_normalize=do_normalize,
599
+ image_mean=image_mean,
600
+ image_std=image_std,
601
+ data_format=data_format,
602
+ do_convert_rgb=do_convert_rgb,
603
+ input_data_format=input_data_format,
604
+ predetermined_grid_thw=predetermined_grid_thw,
605
+ )
606
+ pixel_values.extend(patches)
607
+ vision_grid_thws.append(video_grid_thw)
608
+ pixel_values = np.array(pixel_values)
609
+ vision_grid_thws = np.array(vision_grid_thws)
610
+
611
+ data.update(
612
+ {
613
+ "pixel_values_videos": pixel_values,
614
+ "video_grid_thw": vision_grid_thws,
615
+ }
616
+ )
617
+
618
+ return BatchFeature(data=data, tensor_type=return_tensors)
619
+
620
+
621
+ class Ernie4_5_VLTokenizer(PreTrainedTokenizer):
622
+ """
623
+ Ernie4_5_VLTokenizer
624
+ """
625
+
626
+ vocab_files_names = {
627
+ "vocab_file": "tokenizer.model",
628
+ }
629
+ # Model input names expected by the tokenizer
630
+ model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
631
+ # Padding side (where to add padding tokens)
632
+ padding_side = "right"
633
+
634
+ def __init__(
635
+ self,
636
+ vocab_file,
637
+ bos_token="<s>",
638
+ cls_token="<cls>",
639
+ eos_token="</s>",
640
+ mask_token="<mask:0>",
641
+ pad_token="<pad>",
642
+ sep_token="<sep>",
643
+ unk_token="<unk>",
644
+ additional_special_tokens=None,
645
+ **kwargs,
646
+ ):
647
+ """
648
+ Initialize the Ernie4_5_VLTokenizer
649
+
650
+ Args:
651
+ vocab_file (str): Path to the tokenizer vocabulary model.
652
+ bos_token (str, optional): The beginning of sequence token. Defaults to `"<s>"`.
653
+ cls_token (str, optional): The classifier token. Defaults to `"<cls>"`.
654
+ eos_token (str, optional): The end of sequence token. Defaults to `"</s>"`.
655
+ mask_token (str, optional): The masking token. Defaults to `"<mask:0>"`.
656
+ pad_token (str, optional): The padding token. Defaults to `"<pad>"`.
657
+ sep_token (str, optional): The separation token. Defaults to `"<sep>"`.
658
+ unk_token (str, optional): The unknown tokens symbol. Defaults to `"<unk>"`.
659
+ additional_special_tokens (List[str], optional): Additional special tokens to use.
660
+ Defaults to `["<mask:1>", "<mask:7>"]`.
661
+ **kwargs (dict): Additional keyword arguments passed along to the superclass.
662
+ """
663
+
664
+ # Store vocabulary file path
665
+ self.vocab_file = vocab_file
666
+ # Initialize SentencePiece processor
667
+ self.sp_model = spm.SentencePieceProcessor()
668
+ # Load the vocabulary model
669
+ self.sp_model.Load(vocab_file)
670
+
671
+ # Set default additional special tokens if none provided
672
+ if additional_special_tokens is None:
673
+ additional_special_tokens = ["<mask:1>", "<mask:7>"]
674
+ super().__init__(
675
+ bos_token=bos_token,
676
+ cls_token=cls_token,
677
+ eos_token=eos_token,
678
+ mask_token=mask_token,
679
+ pad_token=pad_token,
680
+ sep_token=sep_token,
681
+ unk_token=unk_token,
682
+ additional_special_tokens=additional_special_tokens,
683
+ **kwargs,
684
+ )
685
+
686
+ @property
687
+ def space_token(self):
688
+ """Return the space token"""
689
+ return "<mask:1>"
690
+
691
+ @property
692
+ def space_token_id(self):
693
+ """Return the ID of the space token"""
694
+ return self.sp_model.piece_to_id("<mask:1>")
695
+
696
+ @property
697
+ def gend_token(self):
698
+ """Return the gender token"""
699
+ return "<mask:7>"
700
+
701
+ @property
702
+ def gend_token_id(self):
703
+ """Return the ID of the gender token"""
704
+ return self.sp_model.piece_to_id("<mask:7>")
705
+
706
+ @property
707
+ def im_start_id(self):
708
+ """Return the ID of the image start token"""
709
+ return self.sp_model.piece_to_id("<|im_start|>")
710
+
711
+ @property
712
+ def im_end_id(self):
713
+ """Return the ID of the image end token"""
714
+ return self.sp_model.piece_to_id("<|im_end|>")
715
+
716
+ @property
717
+ def vocab_size(self):
718
+ """Return the size of the vocabulary"""
719
+ return self.sp_model.vocab_size()
720
+
721
+ def get_vocab(self):
722
+ """Return the vocabulary as a dictionary mapping tokens to IDs"""
723
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
724
+ vocab.update(self.added_tokens_encoder)
725
+ return vocab
726
+
727
+ def _tokenize(self, text):
728
+ """Tokenize the input text into pieces"""
729
+ return self.sp_model.encode_as_pieces(text)
730
+
731
+ def _convert_token_to_id(self, token):
732
+ """Convert a token to its corresponding ID"""
733
+ return self.sp_model.piece_to_id(token)
734
+
735
+ def _convert_id_to_token(self, id):
736
+ """Convert an ID to its corresponding token"""
737
+ return self.sp_model.id_to_piece(id)
738
+
739
+ def convert_tokens_to_string(self, tokens):
740
+ """Convert a sequence of tokens back to a string"""
741
+ current_sub_tokens = []
742
+ out_string = ""
743
+
744
+ for token in tokens:
745
+ # Handle special tokens differently
746
+ if token in self.all_special_tokens:
747
+ out_string += self.sp_model.decode(current_sub_tokens) + token
748
+ current_sub_tokens = []
749
+ else:
750
+ current_sub_tokens.append(token)
751
+
752
+ # Add any remaining sub-tokens
753
+ out_string += self.sp_model.decode(current_sub_tokens)
754
+ return out_string
755
+
756
+ def prepare_for_model(self, *args, **kwargs):
757
+ """Prepare the tokenized inputs for the model"""
758
+ # Remove add_special_tokens if present (not supported)
759
+ if "add_special_tokens" in kwargs:
760
+ kwargs.pop("add_special_tokens")
761
+ return super().prepare_for_model(*args, **kwargs)
762
+
763
+ def save_vocabulary(
764
+ self, save_directory, filename_prefix: Optional[str] = None
765
+ ) -> Tuple[str]:
766
+ """
767
+ Save the vocabulary and special tokens file to a directory.
768
+
769
+ Args:
770
+ save_directory (`str`): The directory to save the vocabulary to
771
+ filename_prefix (`str`, optional): Prefix to add to the filename
772
+
773
+ Returns:
774
+ `Tuple(str)`: Paths to the saved files
775
+ """
776
+ if not os.path.isdir(save_directory):
777
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
778
+ return
779
+
780
+ # Construct output vocabulary file path
781
+ out_vocab_file = os.path.join(
782
+ save_directory,
783
+ (filename_prefix + "-" if filename_prefix else "")
784
+ + self.vocab_files_names["vocab_file"],
785
+ )
786
+
787
+ # Copy or create vocabulary file
788
+ if os.path.abspath(self.vocab_file) != os.path.abspath(
789
+ out_vocab_file
790
+ ) and os.path.isfile(self.vocab_file):
791
+ copyfile(self.vocab_file, out_vocab_file)
792
+ elif not os.path.isfile(self.vocab_file):
793
+ with open(out_vocab_file, "wb") as fi:
794
+ content_spiece_model = self.sp_model.serialized_model_proto()
795
+ fi.write(content_spiece_model)
796
+
797
+ return (out_vocab_file,)
798
+
799
+ def _decode(self, *args, **kwargs):
800
+ """Decode token_id back to text"""
801
+ # Remove some parameters that aren't used
802
+ kwargs.pop("clean_up_tokenization_spaces", None)
803
+ kwargs.pop("spaces_between_special_tokens", None)
804
+
805
+ # Call parent decode method with specific parameters
806
+ return super()._decode(
807
+ *args,
808
+ **kwargs,
809
+ clean_up_tokenization_spaces=False,
810
+ spaces_between_special_tokens=False,
811
+ )
812
+
813
+ def _pad(
814
+ self,
815
+ encoded_inputs: Dict,
816
+ max_length: Optional[int] = None,
817
+ padding_strategy=PaddingStrategy.DO_NOT_PAD,
818
+ pad_to_multiple_of: Optional[int] = None,
819
+ return_attention_mask: Optional[bool] = None,
820
+ ) -> dict:
821
+ """Pad the encoded inputs to the specified length"""
822
+ if return_attention_mask is None:
823
+ return_attention_mask = "attention_mask" in self.model_input_names
824
+ if return_attention_mask:
825
+ required_input = encoded_inputs[self.model_input_names[0]]
826
+ if padding_strategy == PaddingStrategy.LONGEST:
827
+ max_length = len(required_input)
828
+
829
+ # Adjust max_length if needed for multiple of padding
830
+ if (
831
+ max_length is not None
832
+ and pad_to_multiple_of is not None
833
+ and (max_length % pad_to_multiple_of != 0)
834
+ ):
835
+ max_length = (
836
+ (max_length // pad_to_multiple_of) + 1
837
+ ) * pad_to_multiple_of
838
+
839
+ # Check if padding is needed
840
+ needs_to_be_padded = (
841
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
842
+ and len(required_input) != max_length
843
+ )
844
+
845
+ # Handle attention mask if present
846
+ if (
847
+ "attention_mask" in encoded_inputs
848
+ and encoded_inputs["attention_mask"] is not None
849
+ ):
850
+ attention_mask = encoded_inputs.pop("attention_mask")
851
+ if isinstance(attention_mask, torch.Tensor):
852
+ attention_mask = attention_mask.numpy()
853
+ elif isinstance(attention_mask, list):
854
+ attention_mask = np.array(attention_mask)
855
+ elif not isinstance(attention_mask, np.ndarray):
856
+ raise ValueError(
857
+ f"Unexpected type {type(attention_mask)} of attention_mask, "
858
+ )
859
+ else:
860
+ # Create default attention mask if none provided
861
+ attention_mask = np.tril(
862
+ np.ones((len(required_input), len(required_input)), dtype=np.int64)
863
+ )
864
+ attention_mask = np.expand_dims(attention_mask, axis=0)
865
+
866
+ # Perform padding if needed
867
+ if needs_to_be_padded:
868
+ difference = max_length - len(required_input)
869
+ if self.padding_side == "right":
870
+ if attention_mask.ndim == 1:
871
+ pad_width = [(0, difference)]
872
+ else:
873
+ pad_width = [(0, 0), (0, difference), (0, difference)]
874
+ elif self.padding_side == "left":
875
+ if attention_mask.ndim == 1:
876
+ pad_width = [(difference, 0)]
877
+ else:
878
+ pad_width = [(0, 0), (difference, 0), (difference, 0)]
879
+ else:
880
+ raise ValueError(
881
+ "Invalid padding strategy:" + str(self.padding_side)
882
+ )
883
+
884
+ attention_mask = np.pad(
885
+ attention_mask,
886
+ pad_width=pad_width,
887
+ mode="constant",
888
+ constant_values=0,
889
+ )
890
+
891
+ # Call parent padding method
892
+ encoded_inputs = super()._pad(
893
+ encoded_inputs,
894
+ max_length,
895
+ padding_strategy=padding_strategy,
896
+ pad_to_multiple_of=pad_to_multiple_of,
897
+ return_attention_mask=False,
898
+ )
899
+
900
+ # Add attention mask back if needed
901
+ if return_attention_mask:
902
+ encoded_inputs["attention_mask"] = attention_mask.tolist()
903
+
904
+ return encoded_inputs
905
+
906
+
907
+ RAW_VIDEO_DIR = "./download_tmp/raw_video/"
908
+ RAW_IMAGE_DIR = "./download_tmp/raw_images/"
909
+ EXTRACTED_FRAME_DIR = "./download_tmp/extracted_frames/"
910
+ TMP_DIR = "./download_tmp/upload_tmp/"
911
+
912
+ FONT_PATH = os.path.join(Path(__file__).parent.absolute(), "Roboto-Regular.ttf")
913
+
914
+
915
+ def is_gif(data: bytes) -> bool:
916
+ """
917
+ Check whether a bytes object is a GIF based on its magic header.
918
+ """
919
+ return data[:6] in (b"GIF87a", b"GIF89a")
920
+
921
+
922
+ class VideoReaderWrapper(decord.VideoReader):
923
+ """
924
+ Solving memory leak bug
925
+
926
+ https://github.com/dmlc/decord/issues/208
927
+ """
928
+
929
+ def __init__(self, video_path, *args, **kwargs):
930
+ with ntf(delete=True, suffix=".gif") as gif_file:
931
+ gif_input = None
932
+ self.original_file = None
933
+ if isinstance(video_path, str):
934
+ self.original_file = video_path
935
+ if video_path.lower().endswith(".gif"):
936
+ gif_input = video_path
937
+ elif isinstance(video_path, bytes):
938
+ if is_gif(video_path):
939
+ gif_file.write(video_path)
940
+ gif_input = gif_file.name
941
+ elif isinstance(video_path, io.BytesIO):
942
+ video_path.seek(0)
943
+ tmp_bytes = video_path.read()
944
+ video_path.seek(0)
945
+ if is_gif(tmp_bytes):
946
+ gif_file.write(tmp_bytes)
947
+ gif_input = gif_file.name
948
+
949
+ if gif_input is not None:
950
+ clip = mp.VideoFileClip(gif_input)
951
+ mp4_file = ntf(delete=False, suffix=".mp4")
952
+ clip.write_videofile(mp4_file.name, verbose=False, logger=None)
953
+ clip.close()
954
+ video_path = mp4_file.name
955
+ self.original_file = video_path
956
+
957
+ super().__init__(video_path, *args, **kwargs)
958
+ self.seek(0)
959
+
960
+ def __getitem__(self, key):
961
+ frames = super().__getitem__(key)
962
+ self.seek(0)
963
+ return frames
964
+
965
+ def __del__(self):
966
+ if self.original_file and os.path.exists(self.original_file):
967
+ os.remove(self.original_file)
968
+
969
+
970
+ def get_filename(url=None):
971
+ """
972
+ Get Filename
973
+ """
974
+ if url is None:
975
+ return str(uuid.uuid4()).replace("-", "")
976
+ t = datetime.datetime.now()
977
+ if not isinstance(url, bytes):
978
+ url = url.encode("utf-8")
979
+
980
+ md5_hash = hashlib.md5(url).hexdigest()
981
+ pid = os.getpid()
982
+ tid = threading.get_ident()
983
+
984
+ # Remove the suffix to prevent save-jpg from reporting errors
985
+ image_filname = f"{t.year}-{t.month:02d}-{t.day:02d}-{pid}-{tid}-{md5_hash}"
986
+ return image_filname
987
+
988
+
989
+ def file_download(url, download_dir, save_to_disk=False, retry=0, retry_interval=3):
990
+ """
991
+ Description: Download url, if url is PIL, return directly
992
+ Args:
993
+ url(str, PIL): http/local path/io.Bytes, note that io.Bytes is the image byte stream
994
+ download_path: when save_to_disk=True, return the saved address
995
+ save_to_disk: whether to save in the local path
996
+ """
997
+
998
+ if isinstance(url, Image.Image):
999
+ return url
1000
+ elif isinstance(url, VideoReaderWrapper):
1001
+ return url
1002
+ elif url.startswith("http"):
1003
+ response = requests.get(url)
1004
+ bytes_data = response.content
1005
+ elif os.path.isfile(url):
1006
+ if save_to_disk:
1007
+ return url
1008
+ bytes_data = open(url, "rb").read()
1009
+ else:
1010
+ bytes_data = base64.b64decode(url)
1011
+ if not save_to_disk:
1012
+ return bytes_data
1013
+
1014
+ download_path = os.path.join(download_dir, get_filename(url))
1015
+ Path(download_path).parent.mkdir(parents=True, exist_ok=True)
1016
+ with open(download_path, "wb") as f:
1017
+ f.write(bytes_data)
1018
+ return download_path
1019
+
1020
+
1021
+ def get_downloadable(
1022
+ url, download_dir=RAW_VIDEO_DIR, save_to_disk=False, retry=0, retry_interval=3
1023
+ ):
1024
+ """download video and store it in the disk
1025
+
1026
+ return downloaded **path** if save_to_disk is set to true
1027
+ return downloaded **bytes** if save_to_disk is set to false
1028
+ """
1029
+
1030
+ if not os.path.exists(download_dir):
1031
+ os.makedirs(download_dir)
1032
+ downloaded_path = file_download(
1033
+ url,
1034
+ download_dir,
1035
+ save_to_disk=save_to_disk,
1036
+ retry=retry,
1037
+ retry_interval=retry_interval,
1038
+ )
1039
+ return downloaded_path
1040
+
1041
+
1042
+ def get_downloadable_image(
1043
+ download_path, need_exif_info, retry_max_time=0, retry_interval=3
1044
+ ):
1045
+ """
1046
+ Get downloadable with exif info and image processing
1047
+ """
1048
+
1049
+ def get_image_exif(image):
1050
+ exif_data = image._getexif()
1051
+ exif_info = {}
1052
+ if exif_data is not None:
1053
+ for tag, value in exif_data.items():
1054
+ tag_name = TAGS.get(tag, tag)
1055
+ exif_info[tag_name] = value.strip()
1056
+ return exif_info
1057
+
1058
+ def has_transparent_background(img):
1059
+ """has_transparent_background"""
1060
+ if img.mode in ("RGBA", "LA") or (
1061
+ img.mode == "P" and "transparency" in img.info
1062
+ ):
1063
+ # Check for any pixel with alpha channel less than 255 (fully opaque)
1064
+ alpha = img.convert("RGBA").split()[-1]
1065
+ if alpha.getextrema()[0] < 255:
1066
+ return True
1067
+ return False
1068
+
1069
+ def add_white_background(img):
1070
+ """
1071
+ Add a white background to a transparent background image
1072
+ """
1073
+ if img.mode != "RGBA":
1074
+ img = img.convert("RGBA")
1075
+ # Create an image with a white background and the same size as the original image
1076
+ img_white_background = Image.new("RGBA", img.size, (255, 255, 255))
1077
+
1078
+ # Paste the original image onto a white background
1079
+ img_white_background.paste(img, (0, 0), img)
1080
+
1081
+ return img_white_background
1082
+
1083
+ def change_I16_to_L(img):
1084
+ """
1085
+ Convert image from I;16 mode to L mode
1086
+ """
1087
+ # Since the point function in I mode only supports addition, subtraction, and multiplication,
1088
+ # the following * (1 / 256) cannot be changed to division.
1089
+ return img.point(lambda i: i * (1 / 256)).convert("L")
1090
+
1091
+ image = get_downloadable(
1092
+ download_path,
1093
+ save_to_disk=False,
1094
+ retry=retry_max_time,
1095
+ retry_interval=retry_interval,
1096
+ )
1097
+ if isinstance(image, Image.Image):
1098
+ pil_image = image
1099
+ else:
1100
+ pil_image = Image.open(io.BytesIO(image))
1101
+ if need_exif_info:
1102
+ try:
1103
+ exif_info = get_image_exif(pil_image)
1104
+ except Exception as why:
1105
+ exif_info = {}
1106
+ else:
1107
+ exif_info = {}
1108
+
1109
+ try:
1110
+ if pil_image.mode == "I;16":
1111
+ pil_image = change_I16_to_L(pil_image)
1112
+ if has_transparent_background(pil_image):
1113
+ pil_image = add_white_background(pil_image)
1114
+ except Exception as e:
1115
+ pass
1116
+
1117
+ return pil_image.convert("RGB"), exif_info
1118
+
1119
+
1120
+ def read_video_decord(video_path, save_to_disk):
1121
+ """get reader and meta by decord"""
1122
+ video_path = get_downloadable(video_path, save_to_disk=save_to_disk)
1123
+ if isinstance(video_path, VideoReaderWrapper):
1124
+ video_reader = video_path
1125
+ else:
1126
+ if isinstance(video_path, bytes):
1127
+ video_path = io.BytesIO(video_path)
1128
+ video_reader = VideoReaderWrapper(video_path, num_threads=1)
1129
+ vlen = len(video_reader)
1130
+ fps = video_reader.get_avg_fps()
1131
+ duration = vlen / float(fps)
1132
+
1133
+ video_meta = {"fps": fps, "duration": duration, "num_of_frame": vlen}
1134
+
1135
+ return video_reader, video_meta, video_path
1136
+
1137
+
1138
+ def get_frame_indices(
1139
+ vlen,
1140
+ target_frames=-1,
1141
+ target_fps=-1,
1142
+ frames_sample="middle",
1143
+ fix_start=None,
1144
+ input_fps=-1,
1145
+ ):
1146
+ """get_frame_indices"""
1147
+ assert frames_sample in ["rand", "middle", "leading"]
1148
+ if target_frames > 0:
1149
+ assert target_fps <= 0, "target_fps must be negative if target_frames is given."
1150
+ if target_frames > vlen:
1151
+ acc_samples = vlen
1152
+ logger.info(
1153
+ f"target_frames={target_frames} is larger than video length {vlen}, "
1154
+ f"will sample {acc_samples} frames."
1155
+ )
1156
+ else:
1157
+ acc_samples = target_frames
1158
+ logger.debug(
1159
+ f"sampling at target_frames={target_frames}, frames_sample={frames_sample}"
1160
+ )
1161
+
1162
+ # split the video into `acc_samples` intervals, and sample from each interval.
1163
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
1164
+ ranges = []
1165
+ for idx, interv in enumerate(intervals[:-1]):
1166
+ ranges.append((interv, intervals[idx + 1] - 1))
1167
+ if frames_sample == "rand":
1168
+ try:
1169
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
1170
+ except Exception as e:
1171
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
1172
+ frame_indices.sort()
1173
+ frame_indices = list(frame_indices)
1174
+ elif fix_start is not None:
1175
+ frame_indices = [x[0] + fix_start for x in ranges]
1176
+ elif frames_sample == "leading":
1177
+ frame_indices = [x[0] for x in ranges]
1178
+ elif frames_sample == "middle":
1179
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
1180
+ else:
1181
+ raise NotImplementedError
1182
+
1183
+ elif target_fps > 0:
1184
+ assert (
1185
+ target_frames <= 0
1186
+ ), "target_frames must be negative if target_fps is given."
1187
+ assert input_fps > 0, "input_fps must be provided if target_fps is given."
1188
+ logger.info(f"sampling at fps={target_fps}, frames_sample={frames_sample}")
1189
+ duration = float(vlen) / input_fps
1190
+ delta = (
1191
+ 1 / target_fps
1192
+ ) # gap between frames, this is also the clip length each frame represents
1193
+ if frames_sample == "middle":
1194
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
1195
+ elif frames_sample == "leading":
1196
+ frame_seconds = np.arange(0, duration, delta)
1197
+ if frames_sample == "rand":
1198
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
1199
+ rand_offset = np.random.rand(*(frame_seconds.shape)) - 0.5
1200
+ frame_seconds += rand_offset * delta
1201
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
1202
+ frame_indices = [e for e in frame_indices if e < vlen]
1203
+
1204
+ else:
1205
+ raise ValueError(
1206
+ "Must provide either positive target_fps or positive target_frames."
1207
+ )
1208
+
1209
+ return frame_indices
1210
+
1211
+
1212
+ def read_frames_decord(
1213
+ video_path,
1214
+ video_reader,
1215
+ video_meta,
1216
+ target_frames=-1,
1217
+ target_fps=-1,
1218
+ frames_sample="middle",
1219
+ fix_start=None,
1220
+ save_to_disk=False,
1221
+ cache_dir=EXTRACTED_FRAME_DIR,
1222
+ frame_indices=None,
1223
+ tol=10,
1224
+ ):
1225
+ """get frames by decord"""
1226
+
1227
+ if frame_indices is None:
1228
+ frame_indices = get_frame_indices(
1229
+ video_meta["num_of_frame"],
1230
+ target_frames=target_frames,
1231
+ target_fps=target_fps,
1232
+ frames_sample=frames_sample,
1233
+ fix_start=fix_start,
1234
+ input_fps=video_meta["fps"],
1235
+ )
1236
+
1237
+ frames = []
1238
+ for frame_indice_index in range(0, len(frame_indices)):
1239
+ frame_indice = frame_indices[frame_indice_index]
1240
+ try:
1241
+ frames.append(video_reader[frame_indice].asnumpy()) # (T, H, W, C)
1242
+ except Exception as e:
1243
+ logger.debug(f"encounter error when get frame: {frame_indice}, error: {e}")
1244
+ previous_counter = 1
1245
+ later_counter = 1
1246
+ previous_after_flag = True
1247
+ if frame_indice == 0 or frame_indice == len(video_reader) - 1:
1248
+ cur_tol = tol * 2
1249
+ else:
1250
+ cur_tol = tol
1251
+ while previous_counter < cur_tol or later_counter < cur_tol:
1252
+ if previous_after_flag:
1253
+ if frame_indice - previous_counter < 0:
1254
+ previous_counter += 1
1255
+ previous_after_flag = not previous_after_flag
1256
+ continue
1257
+ try:
1258
+ frames.append(
1259
+ video_reader[frame_indice - previous_counter].asnumpy()
1260
+ )
1261
+ logger.info(
1262
+ f"replace {frame_indice}-th frame with {frame_indice-previous_counter}-th frame"
1263
+ )
1264
+ frame_indices[frame_indice_index] = (
1265
+ frame_indice - previous_counter
1266
+ )
1267
+ break
1268
+ except Exception as e:
1269
+ previous_counter += 1
1270
+ else:
1271
+ if frame_indice + later_counter >= len(video_reader):
1272
+ later_counter += 1
1273
+ previous_after_flag = not previous_after_flag
1274
+ continue
1275
+ try:
1276
+ frames.append(
1277
+ video_reader[frame_indice + later_counter].asnumpy()
1278
+ )
1279
+ logger.info(
1280
+ f"replace {frame_indice}-th frame with {frame_indice+later_counter}-th frame"
1281
+ )
1282
+ frame_indices[frame_indice_index] = frame_indice + later_counter
1283
+ break
1284
+ except Exception as e:
1285
+ later_counter += 1
1286
+ previous_after_flag = not previous_after_flag
1287
+
1288
+ frames = np.stack(frames, axis=0)
1289
+ assert len(frames) == len(
1290
+ frame_indices
1291
+ ), f"len(frames): {len(frames)} != len(frame_indices): {len(frame_indices)}"
1292
+
1293
+ ret = []
1294
+
1295
+ url_sha1 = get_filename()
1296
+ for idx, frame in enumerate(frames):
1297
+ tmp = Image.fromarray(frame, "RGB")
1298
+ if save_to_disk:
1299
+ save_path = os.path.join(cache_dir, f"{url_sha1}", f"{idx}.png")
1300
+ if not os.path.exists(os.path.dirname(save_path)):
1301
+ os.makedirs(os.path.dirname(save_path))
1302
+ tmp.save(save_path)
1303
+ tmp = save_path
1304
+ ret.append(tmp)
1305
+
1306
+ time_stamps = [
1307
+ frame_idx * video_meta["duration"] / video_meta["num_of_frame"]
1308
+ for frame_idx in frame_indices
1309
+ ]
1310
+
1311
+ return ret, frame_indices, time_stamps
1312
+
1313
+
1314
+ def render_single_image_with_timestamp(
1315
+ image: Image, number: str, rate: float, font_path: str = FONT_PATH
1316
+ ):
1317
+ """
1318
+ Function: Renders a timestamp to the image of pil.image
1319
+ The timestamp size is the rate of min(width, height)
1320
+ The font color is black, the outline is white, and the outline size is 10% of the font
1321
+ Returns an Image object
1322
+ """
1323
+ draw = ImageDraw.Draw(image)
1324
+ width, height = image.size
1325
+ font_size = int(min(width, height) * rate)
1326
+ outline_size = int(font_size * 0.1)
1327
+ font = ImageFont.truetype(font_path, font_size)
1328
+ x = 0
1329
+ y = 0
1330
+
1331
+ # Draw a black timestamp with a white border
1332
+ draw.text(
1333
+ (x, y),
1334
+ number,
1335
+ font=font,
1336
+ fill=(0, 0, 0),
1337
+ stroke_width=outline_size,
1338
+ stroke_fill=(255, 255, 255),
1339
+ )
1340
+
1341
+ return image
1342
+
1343
+
1344
+ def timestamp_converting(time_stamp_in_seconds):
1345
+ """
1346
+ convert timestamp format from seconds to hr:min:sec
1347
+ """
1348
+ # get hours
1349
+ hours = 0
1350
+ while time_stamp_in_seconds >= 3600:
1351
+ hours += 1
1352
+ time_stamp_in_seconds -= 3600
1353
+ # get minutes
1354
+ mins = 0
1355
+ while time_stamp_in_seconds >= 60:
1356
+ mins += 1
1357
+ time_stamp_in_seconds -= 60
1358
+ time_hours = f"{int(hours):02d}"
1359
+ time_mins = f"{int(mins):02d}"
1360
+ time_secs = f"{time_stamp_in_seconds:05.02f}"
1361
+ fi_time_stamp = time_hours + ":" + time_mins + ":" + time_secs
1362
+
1363
+ return fi_time_stamp
1364
+
1365
+
1366
+ def render_frame_timestamp(frame, timestamp, font_rate=0.1):
1367
+ """
1368
+ Function, given a frame, render the index in order
1369
+ Logic: render the index to the upper left corner of the image
1370
+ frame: frame, PIL.Image object
1371
+ timestamp: timestamp, in seconds
1372
+ font_rate: the ratio of font size to min(wi, hei)
1373
+ """
1374
+ time_stamp = "time: " + timestamp_converting(timestamp)
1375
+ new_frame = render_single_image_with_timestamp(frame, time_stamp, font_rate)
1376
+
1377
+ return new_frame
1378
 
1379
 
1380
  IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3}
 
1811
  return list(tokenizer_input_names) + list(image_processor_input_names)
1812
 
1813
 
1814
+ __all__ = ["Ernie_45T_VLImageProcessor", "Ernie4_5_VLTokenizer", "Ernie_45T_VLProcessor"]
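
The additions above fold the image processor, tokenizer, and video-frame helpers into this one module. A hedged sketch of exercising the merged-in image processor directly, assuming the file is importable locally and its module-level dependencies (torch, decord, moviepy, sentencepiece, Pillow, requests) are installed; the input below is a made-up dummy image:

```python
import numpy as np
from processing_ernie_45t_vl import Ernie_45T_VLImageProcessor, smart_resize

# smart_resize snaps (height, width) to multiples of `factor` while keeping the
# total pixel count inside [min_pixels, max_pixels] and roughly preserving aspect ratio.
print(smart_resize(height=1080, width=1920, factor=28))  # -> (1092, 1932)

image_processor = Ernie_45T_VLImageProcessor()
dummy_image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
features = image_processor.preprocess(images=dummy_image, return_tensors="np")
print(features["pixel_values"].shape)  # (grid_t * grid_h * grid_w, 3 * patch_size**2)
print(features["image_grid_thw"])      # one (t, h, w) grid per image, e.g. [[1, 34, 46]]
```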
tokenizer_config.json CHANGED
@@ -14,7 +14,7 @@
14
  "tokenizer_class": "Ernie4_5_VLTokenizer",
15
  "auto_map": {
16
  "AutoTokenizer": [
17
- "tokenization_ernie_45t_vl.Ernie4_5_VLTokenizer",
+ "processing_ernie_45t_vl.Ernie4_5_VLTokenizer",
18
  null
19
  ]
20
  },
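
With the auto_map repointed, AutoTokenizer now builds Ernie4_5_VLTokenizer from processing_ernie_45t_vl.py instead of the previously referenced tokenization_ernie_45t_vl.py. A minimal sketch (again with a placeholder repo path and trust_remote_code for the custom class):

```python
from transformers import AutoTokenizer

# "<repo-or-local-path>" is a placeholder for the model repo id or a local checkout.
tokenizer = AutoTokenizer.from_pretrained("<repo-or-local-path>", trust_remote_code=True)

ids = tokenizer("Hello, ERNIE 4.5 VL!")["input_ids"]
print(ids)
print(tokenizer.decode(ids))
```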