weege007 commited on
Commit
3e40431
·
verified ·
1 Parent(s): 5951289

add streamer for infer method when not eval mode

Browse files
Files changed (1) hide show
  1. modeling_deepseekocr.py +2 -2
modeling_deepseekocr.py CHANGED
@@ -700,7 +700,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
700
 
701
 
702
 
703
- def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
704
  self.disable_torch_init()
705
 
706
  os.makedirs(output_path, exist_ok=True)
@@ -910,7 +910,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
910
 
911
 
912
  if not eval_mode:
913
- streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
914
  with torch.autocast("cuda", dtype=torch.bfloat16):
915
  with torch.no_grad():
916
  output_ids = self.generate(
 
700
 
701
 
702
 
703
+ def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False, streamer=None):
704
  self.disable_torch_init()
705
 
706
  os.makedirs(output_path, exist_ok=True)
 
910
 
911
 
912
  if not eval_mode:
913
+ streamer = streamer or NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
914
  with torch.autocast("cuda", dtype=torch.bfloat16):
915
  with torch.no_grad():
916
  output_ids = self.generate(