add streamer and verbose params, and io.BytesIO / PIL.Image.Image support for image_file, in the infer method

#23
by weege007 - opened
Files changed (1): modeling_deepseekocr.py (+22 -14)
@@ -27,7 +27,9 @@ import time
 def load_image(image_path):
 
     try:
-        image = Image.open(image_path)
+        image = image_path
+        if not isinstance(image_path, Image.Image):
+            image = Image.open(image_path)
 
         corrected_image = ImageOps.exif_transpose(image)
 
@@ -353,6 +355,7 @@ class DeepseekOCRConfig(DeepseekV2Config):
 
 class DeepseekOCRModel(DeepseekV2Model):
     config_class = DeepseekOCRConfig
+    verbose = True
 
     def __init__(self, config: DeepseekV2Config):
         super(DeepseekOCRModel, self).__init__(config)
@@ -432,10 +435,11 @@ class DeepseekOCRModel(DeepseekV2Model):
                     global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                     global_features = self.projector(global_features)
 
-                    print('=====================')
-                    print('BASE: ', global_features.shape)
-                    print('PATCHES: ', local_features.shape)
-                    print('=====================')
+                    if self.verbose:
+                        print('=====================')
+                        print('BASE: ', global_features.shape)
+                        print('PATCHES: ', local_features.shape)
+                        print('=====================')
 
                     _, hw, n_dim = global_features.shape
                     h = w = int(hw ** 0.5)
@@ -475,10 +479,12 @@ class DeepseekOCRModel(DeepseekV2Model):
                     global_features_2 = vision_model(image_ori, global_features_1)
                     global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                     global_features = self.projector(global_features)
-                    print('=====================')
-                    print('BASE: ', global_features.shape)
-                    print('NO PATCHES')
-                    print('=====================')
+
+                    if self.verbose:
+                        print('=====================')
+                        print('BASE: ', global_features.shape)
+                        print('NO PATCHES')
+                        print('=====================')
                     _, hw, n_dim = global_features.shape
                     h = w = int(hw ** 0.5)
 
@@ -700,11 +706,13 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
+    def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
         self.disable_torch_init()
+        self.model.verbose = verbose
 
-        os.makedirs(output_path, exist_ok=True)
-        os.makedirs(f'{output_path}/images', exist_ok=True)
+        if len(output_path) > 0 :
+            os.makedirs(output_path, exist_ok=True)
+            os.makedirs(f'{output_path}/images', exist_ok=True)
 
         if prompt and image_file:
             conversation = [
@@ -716,7 +724,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                     # "content": "<image>\nFree OCR. ",
                     # "content": "<image>\nParse the figure. ",
                     # "content": "<image>\nExtract the text in the image. ",
-                    "images": [f'{image_file}'],
+                    "images": [image_file] if isinstance(image_file, (BytesIO, Image.Image)) else [f'{image_file}'],
                 },
                 {"role": "<|Assistant|>", "content": ""},
             ]
@@ -910,7 +918,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
         if not eval_mode:
-            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+            streamer = streamer or NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 with torch.no_grad():
                     output_ids = self.generate(
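
For context, a minimal sketch of how the extended `infer` surface could be driven once this change lands. The model id, image path, and the stock `transformers.TextStreamer` used as the custom streamer are illustrative assumptions, not part of the diff:

```python
from io import BytesIO

from PIL import Image
from transformers import AutoModel, AutoTokenizer, TextStreamer

# Illustrative model id and file name; adjust to the actual checkpoint/input.
model_id = "deepseek-ai/DeepSeek-OCR"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()

# 1) Path string: previous behavior, unchanged.
model.infer(tokenizer, prompt="<image>\nFree OCR. ", image_file="doc.png")

# 2) PIL.Image.Image: load_image() now skips Image.open() for these.
pil_img = Image.open("doc.png")
model.infer(tokenizer, prompt="<image>\nFree OCR. ", image_file=pil_img)

# 3) io.BytesIO: forwarded as-is; Image.open() accepts file-like objects.
with open("doc.png", "rb") as f:
    buf = BytesIO(f.read())
model.infer(tokenizer, prompt="<image>\nFree OCR. ", image_file=buf)

# verbose=False suppresses the BASE/PATCHES shape prints; a caller-supplied
# streamer replaces the default NoEOSTextStreamer.
model.infer(
    tokenizer,
    prompt="<image>\nFree OCR. ",
    image_file=pil_img,
    streamer=TextStreamer(tokenizer, skip_prompt=True),
    verbose=False,
)
```

Note that since `output_path` defaults to `''` and directory creation is now guarded by `len(output_path) > 0`, none of the calls above require a writable output directory, and the `streamer` argument is only consulted on the `not eval_mode` path, where the default `NoEOSTextStreamer` would otherwise be constructed.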