weege007 committed
Commit bac4a8b · verified · 1 parent: 0a5453e

Update modeling_deepseekocr.py
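
In short: the `verbose` keyword is removed from `DeepseekOCRModel.forward()` and from the `generate()` kwargs in `infer()`, and replaced with a `verbose` attribute on the model that `infer()` sets before generation, presumably so that `generate()` no longer receives an unrecognized `verbose` keyword.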

Files changed (1): modeling_deepseekocr.py (+6, −7)
modeling_deepseekocr.py CHANGED

@@ -353,6 +353,7 @@ class DeepseekOCRConfig(DeepseekV2Config):
 
 class DeepseekOCRModel(DeepseekV2Model):
     config_class = DeepseekOCRConfig
+    verbose = True
 
     def __init__(self, config: DeepseekV2Config):
         super(DeepseekOCRModel, self).__init__(config)
@@ -383,7 +384,6 @@ class DeepseekOCRModel(DeepseekV2Model):
         images_seq_mask: Optional[torch.FloatTensor] = None,
         images_spatial_crop: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
-        verbose: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
 
 
@@ -433,7 +433,7 @@ class DeepseekOCRModel(DeepseekV2Model):
             global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
             global_features = self.projector(global_features)
 
-            if verbose:
+            if self.verbose:
                 print('=====================')
                 print('BASE: ', global_features.shape)
                 print('PATCHES: ', local_features.shape)
@@ -478,7 +478,7 @@ class DeepseekOCRModel(DeepseekV2Model):
             global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
             global_features = self.projector(global_features)
 
-            if verbose:
+            if self.verbose:
                 print('=====================')
                 print('BASE: ', global_features.shape)
                 print('NO PATCHES')
@@ -706,6 +706,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
     def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None, verbose=True):
         self.disable_torch_init()
+        self.model.verbose = verbose
 
         if len(output_path) > 0 :
             os.makedirs(output_path, exist_ok=True)
@@ -930,8 +931,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 streamer=streamer,
                 max_new_tokens=8192,
                 no_repeat_ngram_size = 20,
-                use_cache = True,
-                verbose = verbose
+                use_cache = True
                 )
 
             else:
@@ -948,8 +948,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 eos_token_id=tokenizer.eos_token_id,
                 max_new_tokens=8192,
                 no_repeat_ngram_size = 35,
-                use_cache = True,
-                verbose = verbose
+                use_cache = True
                 )
 
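
For context, a minimal usage sketch of the behavior after this commit. The repo id, prompt, and image path below are illustrative assumptions, not part of this commit; only the `infer(...)` signature and its `verbose` flag come from the diff above.

import torch
from transformers import AutoModel, AutoTokenizer

model_id = 'deepseek-ai/DeepSeek-OCR'  # assumption: upstream repo id, not from this commit
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
model = model.eval().cuda().to(torch.bfloat16)

# The flag is now a model attribute: infer() sets self.model.verbose = verbose
# and forward() checks self.verbose, so `verbose` never travels through
# generate()'s kwargs.
res = model.infer(
    tokenizer,
    prompt='<image>\nFree OCR.',  # illustrative prompt
    image_file='test.png',        # illustrative path
    output_path='output',
    verbose=False,                # silences the BASE/PATCHES shape prints
)

Note that `DeepseekOCRModel.verbose` defaults to True, so the shape printouts still appear unless `infer()` is called with `verbose=False`.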