add support for batch multimodal understanding
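This change lets the non-generation (understanding) branch of the processor accept a batch of prompts paired with a list of images, including images of different sizes. A minimal usage sketch is shown below; the hub ids, the emu3.mllm.processing_emu3 module path, the mode letter "U" and the return_tensors keyword are assumptions based on typical Emu3 usage and are not part of this diff.

from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

# Assumed hub ids and module path (illustrative, not taken from this diff).
from emu3.mllm.processing_emu3 import Emu3Processor

tokenizer = AutoTokenizer.from_pretrained("BAAI/Emu3-Chat", trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained("BAAI/Emu3-VisionTokenizer", trust_remote_code=True)
vision_tokenizer = AutoModel.from_pretrained("BAAI/Emu3-VisionTokenizer", trust_remote_code=True).eval()
processor = Emu3Processor(image_processor, vision_tokenizer, tokenizer)

# Two prompts and two images that may have different sizes; padding_image=True pads
# them to a common shape so the vision tokenizer can encode the batch in one pass.
texts = ["Describe this image.", "How many people are in this picture?"]
images = [Image.open("photo_a.jpg"), Image.open("photo_b.jpg")]

inputs = processor(text=texts, image=images, mode="U", padding_image=True, return_tensors="pt")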
processing_emu3.py  CHANGED  (+48 -5)
@@ -14,12 +14,14 @@
 # limitations under the License.
 """ Processor class for Emu3. """
 
+from math import ceil
 import re
 from typing import List, Optional, Sequence, Union
 from functools import partial
 
 from PIL import Image
 import torch
+from torch.nn import functional as F
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
 from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
@@ -73,6 +75,7 @@ class Emu3Processor(ProcessorMixin):
         self.vision_tokenizer = vision_tokenizer
         self.prefix_template = prefix_template
         self.visual_template = visual_template
+        self.vis_tok_spatial_factor = 2 ** (len(self.vision_tokenizer.config.ch_mult) - 1)
 
         super().__init__(image_processor, tokenizer, chat_template=chat_template)
         self.const_helper = self.build_const_helper()
@@ -86,6 +89,7 @@ class Emu3Processor(ProcessorMixin):
         mode: str = "G",
         ratio: str | List[str] = "1:1",
         image_area: int = 518400,
+        padding_image: bool = False,
         **kwargs,
     ) -> BatchFeature:
         """
@@ -106,6 +110,8 @@
                 the image width-height ratio for generation
             image_area (`int`, *optional*):
                 image area used to calcualte the generated image height and width
+            padding_image (`bool`, *optional*):
+                whether pad images to same size for fast preprocessing if they have different sizes
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
@@ -121,10 +127,13 @@
         if isinstance(text, str):
             text = [text]
 
+        if isinstance(image, Image.Image):
+            image = [image]
+
         if not isinstance(text[0], str):
             raise ValueError("`text` must be string or list of string")
 
-
+        image_tokens = None
         if mode == 'G':
             if image is not None:
                 raise ValueError("You have to specify only `text` in generation mode")
@@ -144,10 +153,7 @@ class Emu3Processor(ProcessorMixin):
             if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
                 raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
 
-            image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
-            image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
-            image_tokens = self.vision_tokenizer.encode(image_inputs)
-
+            image_tokens = self.tokenize_image(image, padding_image=padding_image)
             if len(text) != len(image_tokens):
                 raise ValueError("number of image must match number of text prompt")
 
@@ -254,6 +260,43 @@ class Emu3Processor(ProcessorMixin):
         tw = int(round(w * target_ratio / spatial_scale_factor))
         return th, tw
 
+    def tokenize_image(self, image: List[Image.Image], *, padding_image: bool = False):
+        is_all_same_size, prev_size = True, None
+        for im in image:
+            if prev_size is not None:
+                is_all_same_size &= (prev_size == im.size)
+            prev_size = im.size
+
+        if is_all_same_size:
+            image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
+            image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
+            image_tokens = self.vision_tokenizer.encode(image_inputs)
+        elif padding_image:
+            image_inputs = [self.image_processor(im, return_tensors="pt")["pixel_values"] for im in image]
+            image_shapes = [im.shape[2:] for im in image_inputs]
+            max_shape = (
+                max([im_shape[0] for im_shape in image_shapes]),
+                max([im_shape[1] for im_shape in image_shapes]),
+            )
+            image_inputs = [
+                F.pad(im_inp, (0, max_shape[1] - im_shape[1], 0, max_shape[0] - im_shape[0]))
+                for im_inp, im_shape in zip(image_inputs, image_shapes)
+            ]
+            image_inputs = torch.cat(image_inputs, dim=0).to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
+            image_tokens = self.vision_tokenizer.encode(image_inputs)
+            image_tokens = [
+                im_tok[:ceil(im_shape[0] / self.vis_tok_spatial_factor), :ceil(im_shape[1] / self.vis_tok_spatial_factor)]
+                for im_tok, im_shape in zip(image_tokens, image_shapes)
+            ]
+        else:
+            image_tokens = []
+            for im in image:
+                image_input = self.image_processor(im, return_tensors="pt")["pixel_values"]
+                image_input = image_input.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
+                image_tokens.append(self.vision_tokenizer.encode(image_input).squeeze(0))
+
+        return image_tokens
+
     def build_const_helper(self):
         (
             img_token,
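Note on the padded branch: F.pad applies its pad tuple to the last dimensions first, so (0, pad_w, 0, pad_h) pads only the right and bottom edges and keeps each image anchored at the top-left. After the padded batch is encoded, each token map is cropped back to ceil(H / vis_tok_spatial_factor) x ceil(W / vis_tok_spatial_factor), where the factor is the spatial downsampling of the vision tokenizer. A small arithmetic sketch, assuming a ch_mult of length 4 (an 8x downsampling factor; the real config value is not shown in this diff):

from math import ceil

ch_mult = (1, 2, 2, 4)                             # assumed example value; only its length matters
vis_tok_spatial_factor = 2 ** (len(ch_mult) - 1)   # -> 8

# A 512x768 pixel_values map padded up to a 768x768 batch max encodes to a 96x96
# token grid; cropping with ceil() recovers the image's own 64x96 token grid.
h, w = 512, 768
print(ceil(h / vis_tok_spatial_factor), ceil(w / vis_tok_spatial_factor))  # 64 96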